class StackViewer(object):
    """
    Attach a 'Frame' slider to a viewer and show one image per frame.

    Parameters
    ----------
    viewer : object
        expected to have an ``update_image`` method and a ``_fig`` attribute
        (the figure the slider axes is added to)
    images : array-like
        must support ``len()`` and integer indexing returning a 2D array
    """
    def __init__(self, viewer, images):
        self.viewer = viewer
        self.images = images
        length = len(self.images)
        fig = self.viewer._fig
        slider_ax = fig.add_axes([0.1, 0.01, 0.8, 0.02])
        self.slider = Slider(slider_ax,
                             'Frame',
                             0,
                             length - 1,
                             0,
                             valfmt='%d/{}'.format(length - 1))
        self.slider.on_changed(self.update)
        self.update(0)  # Trigger the initialization of viewer.

    def update(self, val):
        """Show the frame nearest to ``val``.

        A non-integer value (e.g. from dragging) snaps the slider to the
        nearest integer. ``set_val`` re-enters this callback with that integer,
        and that call displays the frame. Returning right away keeps a second,
        truncated index from overwriting the frame the slider shows.
        """
        if not isinstance(val, int):
            self.slider.set_val(int(round(val)))
            return
        self.viewer.update_image(self.images[val])
Exemple #2
0
def generate_plots():
    """Load a random structure set and redraw its CT slice and contour.

    Everything is exchanged through module-level globals. The function picks
    a case with ``get_random_structure_set`` and loads it with
    ``prepare_dicom_data_for_individual_case``. It optionally smooths the
    human contours, clears ``ax``, and draws the middle slice with its
    contour. The slice slider is created on the first call. Later calls
    re-range and recentre the existing slider.
    """
    global rts_filename, img3d, contour_matrix, im, structure_wanted, list_of_ct_filenames, current_slice_slider, contour_plot

    # get all the relevant data (choose random structure set, get ct and contour data from dicoms)
    structure_wanted, rts_filename, list_of_ct_filenames = get_random_structure_set()

    img3d, contour_matrix, x_origin, y_origin, pixel_size = prepare_dicom_data_for_individual_case(
        rts_filename, list_of_ct_filenames, structure_wanted)

    # smooth the human contours if desired
    if smooth_human_contours and human_rts_filename in rts_filename:
        smooth_contour_matrix()

    # can remove ax.clear() for randomly generated colors of contours
    ax.clear()

    n_slices = img3d.shape[0]
    middle_slice = n_slices // 2

    # The slider global does not exist before the first call. In that case
    # create it; later calls reuse it for the new volume.
    if 'current_slice_slider' not in globals():
        current_slice_slider = Slider(slice_slider_axes, 'CT Slice', 0, n_slices - 1, valstep=1, valfmt='%0.0f')
        current_slice_slider.set_val(middle_slice)
        current_slice_slider.on_changed(on_slider_change)
    else:
        # Widen/narrow the range before moving the handle, so the new value
        # is already inside [valmin, valmax] when the observers run.
        current_slice_slider.valmax = n_slices - 1
        current_slice_slider.ax.set_xlim(current_slice_slider.valmin, current_slice_slider.valmax)
        current_slice_slider.set_val(middle_slice)

    # plot the data, and connect
    slice_index = int(current_slice_slider.val)
    # NOTE(review): the extent pairs shape[1] with x and shape[2] with y. imshow
    # maps axis 1 to rows (y) and axis 2 to columns (x), so this only matters
    # for non-square slices -- confirm against the DICOM geometry.
    im = ax.imshow(img3d[slice_index, :, :],
                   extent=[x_origin, x_origin + img3d.shape[1] * pixel_size, y_origin, y_origin + img3d.shape[2] * pixel_size],
                   cmap='Greys_r', vmin=hu_min, vmax=hu_max, animated = True, interpolation = 'nearest', origin = 'upper')
    contour_plot = ax.plot(contour_matrix[slice_index][0], contour_matrix[slice_index][1])
Exemple #3
0
def main(filenames, equal_axes):
    """Show iso-surface contours from one or more files with a time slider.

    Parameters
    ----------
    filenames : list of str
        Files read with ``read_iso_surfaces``. An entry ``path:name`` uses
        ``name`` as its legend label. Otherwise the label is the entry's
        1-based position.
    equal_axes : bool
        If true, force an equal aspect ratio on the plot.

    For the slider time, each file's closest time step is plotted, unless it
    is more than 1.5 time steps away. Files with a single time step are
    always plotted. The current zoom/pan limits are kept across redraws.
    """
    all_data = []
    tmin = tmax = 0
    xmin, ymin, xmax, ymax = 1e100, 1e100, -1e100, -1e100
    for i, filename in enumerate(filenames):
        name = str(i + 1)
        if ':' in filename:
            # Split on the last colon only, so a label can still be given for
            # a path that itself contains ':'.
            filename, name = filename.rsplit(':', 1)
        print('Reading %s from file %s' % (name, filename))
        _description, _value, _dim, timesteps, data = read_iso_surfaces(
            filename)
        timesteps = numpy.array(timesteps)
        # Largest distance between the slider time and a file's closest time
        # step at which that step is still drawn.
        if len(timesteps) > 1:
            tolerance = 1.5 * (timesteps[1] - timesteps[0])
        else:
            tolerance = numpy.inf
        all_data.append((name, timesteps, data, tolerance))

        if len(timesteps):
            tmax = max(tmax, timesteps[-1])
        for contours in data:
            for contour in contours:
                xmin = min(xmin, numpy.min(contour[0]))
                ymin = min(ymin, numpy.min(contour[1]))
                xmax = max(xmax, numpy.max(contour[0]))
                ymax = max(ymax, numpy.max(contour[1]))

    fig, ax = pyplot.subplots()
    pyplot.subplots_adjust(bottom=0.25)
    axcolor = '#a1b8dd'
    slider_ax = pyplot.axes([0.1, 0.1, 0.8, 0.03], facecolor=axcolor)
    slider = Slider(slider_ax, 'Time', tmin, tmax, valinit=tmin)

    # Pad the data bounds by 10 % on each side. Skip this when no contour
    # point was read, since the sentinel bounds would give nonsensical limits.
    if xmin <= xmax and ymin <= ymax:
        xdiff = xmax - xmin
        ydiff = ymax - ymin
        xmin, xmax = xmin - 0.1 * xdiff, xmax + 0.1 * xdiff
        ymin, ymax = ymin - 0.1 * ydiff, ymax + 0.1 * ydiff
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)

    def update(val):
        """Redraw every file at the slider time, keeping the current view limits."""
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.clear()

        t = slider.val
        for name, timesteps, data, tolerance in all_data:
            if len(timesteps) == 0:
                continue
            idx = numpy.argmin(abs(timesteps - t))
            if abs(timesteps[idx] - t) > tolerance:
                continue
            plotit(ax, data[idx], name)

        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

        if equal_axes:
            ax.set_aspect('equal')

        if len(filenames) > 1:
            ax.legend(loc='lower right')
        fig.canvas.draw_idle()

    slider.on_changed(update)
    slider.set_val(tmin)
    pyplot.show()
class ImgView(object):
    """Show a 2D image with sliders for the lower/upper colour-scale limits.

    Parameters
    ----------
    img : numpy.ndarray
        2D image to display.
    fig_id : optional
        Passed to ``plt.figure`` to choose the figure.
    imin, imax : float, optional
        Slider range. They default to the values at about the 0.5 % and
        99.5 % positions of the sorted pixel values.
    """
    def __init__(self, img, fig_id=None, imin=None, imax=None):
        self.fig = plt.figure(fig_id)
        self.ax = self.fig.add_axes([0.1, 0.25, 0.8, 0.7])
        self.fig.subplots_adjust(bottom=0.25)

        # ``facecolor`` replaces the ``axisbg`` keyword, which matplotlib removed.
        axcolor = 'lightgoldenrodyellow'
        self.axlo = self.fig.add_axes([0.15, 0.1, 0.65, 0.03], facecolor=axcolor)
        self.axhi = self.fig.add_axes([0.15, 0.15, 0.65, 0.03], facecolor=axcolor)
        self.axrefr = self.fig.add_axes([0.15, 0.05, 0.10, 0.03],
                                        facecolor=axcolor)

        # Image artist and colorbar axes; created by the first set_img call.
        self.imgplt = None
        self.cax = None

        imsort = np.sort(img.flatten())
        n = len(imsort)
        if imin is None:
            imin = imsort[int(n * 0.005)]
        if imax is None:
            imax = imsort[int(n * 0.995)]
        # Initial limits: values at about the 2 % and 98 % positions.
        self.slo = Slider(self.axlo,
                          'Scale min',
                          imin,
                          imax,
                          valinit=imsort[int(n * 0.02)])
        self.shi = Slider(self.axhi,
                          'Scale max',
                          imin,
                          imax,
                          valinit=imsort[int(n * 0.98)])
        self.brefr = Button(self.axrefr, 'Refresh')
        self.slo.on_changed(self.update_slider)
        self.shi.on_changed(self.update_slider)
        self.brefr.on_clicked(self.refresh)
        self.set_img(img)

    def update_slider(self, val):
        """Apply the slider limits, keeping the upper limit above the lower one."""
        if self.shi.val <= self.slo.val:
            # set_val re-enters this callback with the adjusted upper limit.
            self.shi.set_val(self.slo.val + 1)
        self.imgplt.set_clim(self.slo.val, self.shi.val)
        self.fig.canvas.draw()

    def set_img(self, img=None):
        """Display ``img``, or redisplay the current image when ``img`` is None.

        The image artist and its colorbar are created only on the first call.
        Later calls update the existing artist, so refreshing does not stack
        images or keep shrinking the main axes with extra colorbar axes.
        """
        if img is not None:
            self.img = img
        height, width = self.img.shape[0], self.img.shape[1]
        if getattr(self, 'imgplt', None) is None:
            self.ax.set_xlim(0, width - 1)
            self.ax.set_ylim(0, height - 1)
            self.imgplt = self.ax.imshow(self.img,
                                         vmin=self.slo.val,
                                         vmax=self.shi.val,
                                         cmap='hot',
                                         origin='lower',
                                         interpolation='nearest')
            divider = make_axes_locatable(self.ax)
            self.cax = divider.append_axes("right", size="5%", pad=0.05)
            self.fig.colorbar(self.imgplt, cax=self.cax)
        else:
            # set_data copies the array, so in-place edits of self.img show up.
            self.imgplt.set_data(self.img)
            # Same extent imshow computes for origin='lower' (shape may change).
            self.imgplt.set_extent((-0.5, width - 0.5, -0.5, height - 0.5))
            self.imgplt.set_clim(self.slo.val, self.shi.val)
            # Re-apply the limits after set_extent, which may adjust them.
            self.ax.set_xlim(0, width - 1)
            self.ax.set_ylim(0, height - 1)
        self.fig.canvas.draw()

    def refresh(self, event=None):
        """Button callback: redraw the current image."""
        self.set_img()
Exemple #5
0
class GUI:
    """Interactive plot of a coin-toss posterior after a chosen number of tosses.

    ``coinToss`` must provide ``tosses`` (the slider maximum), ``doTosses()``,
    ``conditional`` (the x values), and a 2D ``posterior`` whose row ``n`` is
    plotted for ``n`` tosses.

    Controls: the slider or the left/right arrow keys choose the number of
    tosses. The 'Re-toss' button or the 'r' key call ``doTosses`` again.
    """
    def __init__(self, coinToss):
        self._coinToss = coinToss
        maxTosses = coinToss.tosses

        axcolor = 'lightgoldenrodyellow'

        self.figure = pyplot.figure()
        self.mainAxis = pyplot.axes([0.05, 0.2, 0.9, 0.75])

        # Slider for number of tosses
        tossAxis = pyplot.axes([0.1, 0.05, 0.8, 0.05])
        self.tossSlider = Slider(tossAxis,
                                 'Tosses',
                                 0.,
                                 1. * maxTosses,
                                 valinit=0,
                                 valfmt=u'%d')
        self.tossSlider.on_changed(lambda x: self.draw(x))

        # Reset button
        resetAxis = pyplot.axes([0.8, 0.85, 0.1, 0.05])
        self.resetButton = Button(resetAxis,
                                  'Re-toss',
                                  color=axcolor,
                                  hovercolor='0.975')
        self.resetButton.on_clicked(self.retoss)

        # Key press events
        self.figure.canvas.mpl_connect('key_press_event',
                                       lambda x: self.press(x))

        self._coinToss.doTosses()

    def retoss(self, event):
        """Redo the tosses and redraw at the current slider position."""
        self._coinToss.doTosses()
        self.draw(self.tossSlider.val)

    def press(self, event):
        """Arrow keys step the slider by one toss; 'r' re-tosses."""
        slider = self.tossSlider
        if event.key == u'left' and slider.val > slider.valmin:
            # Clamp: a fractional position (from dragging) could otherwise
            # step below the slider range.
            slider.set_val(max(slider.val - 1, slider.valmin))
        if event.key == u'right' and slider.val < slider.valmax:
            slider.set_val(min(slider.val + 1, slider.valmax))
        if event.key == u'r':
            self.retoss(event)

    def draw(self, x=0):
        """Plot the posterior after ``int(x)`` tosses and mark its maximum."""
        c = self._coinToss
        row = c.posterior[int(x), :]
        # (index, value) of the largest posterior entry; first one on ties.
        m = max(enumerate(row), key=operator.itemgetter(1))
        self.mainAxis.clear()
        self.mainAxis.plot(c.conditional,
                           row,
                           lw=2,
                           color='red')
        self.mainAxis.vlines(c.conditional[m[0]], 0, c.posterior.max())
        self.figure.canvas.draw()
Exemple #6
0
    class Index(object):
        """Slider plus Previous/Next buttons for stepping through ``R``.

        Positions ``0 .. len(wavelengths) - 1`` show the image for that
        wavelength. The extra last position ``len(wavelengths)`` shows the
        "Calibrated Pixels" image. Relies on ``wavelengths``, ``R``,
        ``image``, ``cbar``, ``maximum``, ``np`` and ``plt`` from the
        enclosing scope.
        """
        def __init__(self, ax_slider, ax_prev, ax_next):
            self.ind = 0
            self.num = len(wavelengths)
            self.bnext = Button(ax_next, 'Next')
            self.bnext.on_clicked(self.next)
            self.bprev = Button(ax_prev, 'Previous')
            self.bprev.on_clicked(self.prev)
            self.slider = Slider(ax_slider,
                                 "Energy Resolution: {:.2f} nm".format(
                                     wavelengths[0]),
                                 0,
                                 self.num,
                                 valinit=0,
                                 valfmt='%d')
            # The numeric readout is hidden; the label text describes the
            # current position instead (see update).
            self.slider.valtext.set_visible(False)
            self.slider.label.set_horizontalalignment('center')
            self.slider.on_changed(self.update)

            # Move label and readout to (0.5, -0.5) in slider-axes coordinates.
            self.slider.label.set_position((0.5, -0.5))
            self.slider.valtext.set_position((0.5, -0.5))

        def next(self, event):
            """Advance one position, wrapping from the last to the first."""
            self.slider.set_val((self.ind + 1) % (self.num + 1))

        def prev(self, event):
            """Go back one position, wrapping from the first to the last."""
            self.slider.set_val((self.ind - 1) % (self.num + 1))

        def update(self, i):
            """Show position ``int(i)`` and rescale the colorbar to match."""
            self.ind = int(i)
            image.set_data(R[self.ind])
            if self.ind != len(wavelengths):
                self.slider.label.set_text(
                    "Energy Resolution: {:.2f} nm".format(
                        wavelengths[self.ind]))
                cbar.set_clim(vmin=0, vmax=maximum)
                cbar_ticks = np.linspace(0.,
                                         maximum,
                                         num=11,
                                         endpoint=True)
            else:
                self.slider.label.set_text("Calibrated Pixels")
                cbar.set_clim(vmin=0, vmax=1)
                cbar_ticks = np.linspace(0., 1, num=2)
            cbar.set_ticks(cbar_ticks)
            cbar.draw_all()
            plt.draw()
		def adv_window():
			"""Open the 'Advanced settings' figure with segmentation-parameter sliders.

			The sliders start at lambda1, lambda2, smoothing, iterations and rad
			from the enclosing scope. The 'exit' button and any key press call
			adv_exit; the 'start' button calls adv_start.

			Returns the figure, its (hidden) axes and the five sliders.
			"""
			figs, axx = plt.subplots(num = 'Advanced settings')
			figs.canvas.mpl_connect('key_press_event', adv_exit)
			axx.axis('off')

			ax_as1 = plt.axes([0.15, 0.02, 0.5, 0.05])
			slider_as1 = Slider(ax_as1, 'lambda1', 0.1, 4, dragging = True, valstep = 0.1)

			ax_as2 = plt.axes([0.15, 0.10, 0.5, 0.05])
			slider_as2 = Slider(ax_as2, 'lambda2', 0.1, 4, dragging = True, valstep = 0.1)

			ax_as3 = plt.axes([0.15, 0.18, 0.5, 0.05])
			slider_as3 = Slider(ax_as3, 'smoothing', 0, 4, dragging = True, valstep = 1)

			ax_as4 = plt.axes([0.15, 0.26, 0.5, 0.05])
			slider_as4 = Slider(ax_as4, 'iterations', 1, 1000, dragging = True, valstep = 1)

			ax_as5 = plt.axes([0.15, 0.34, 0.5, 0.05])
			slider_as5 = Slider(ax_as5, 'radius', 0.5, 5, dragging = True, valstep = 0.1)

			ax_b1 = plt.axes([0.85, 0.15, 0.07, 0.08])
			ax_b2 = plt.axes([0.85, 0.05, 0.07, 0.08])

			but_as1 = Button(ax_b1, 'exit', color = 'beige', hovercolor = 'beige')
			but_as2 = Button(ax_b2, 'start', color = 'beige', hovercolor = 'beige')

			textstr = "Press ENTER in terminal to start segmentation. \n Shouldn't be necessairy to change settings below, but can be tuned if \n resulting segmentation is not ideal. \n Especially if small nodule: try setting lambda1 <= lambda2 \n Or if very nonhomogeneous nodule: try setting lambda1 > lambda2."

			props = dict(boxstyle='round', facecolor='wheat')
			# Place the help text in this figure's own axes coordinates (axx),
			# not in those of the outer-scope ``ax``.
			axx.text(-0.18, 0.25, textstr, transform=axx.transAxes, fontsize=12,
				verticalalignment='top', bbox=props)

			slider_as1.set_val(lambda1)
			slider_as2.set_val(lambda2)
			slider_as3.set_val(smoothing)
			slider_as4.set_val(iterations)
			slider_as5.set_val(rad)
			but_as1.on_clicked(adv_exit)
			but_as2.on_clicked(adv_start)

			figs.canvas.draw_idle()

			return figs, axx, slider_as1, slider_as2, slider_as3, slider_as4, slider_as5
Exemple #8
0
def plot_hfo_samples(hfo_detection_run: HfoDetectionRun):
    """Plot the detected HFO periods one at a time, chosen with a slider.

    The slider selects a 1-based period number. Each selection redraws the
    bandwidth, spike-train and raster axes via ``_plot_hfo_sample``. The
    function returns without plotting if the last detector run has no
    periods. Depending on the run configuration, the figure is shown and/or
    one plot per period is saved.
    """
    periods = hfo_detection_run.detector.last_run.analytics.periods
    period_windows = list(zip(periods.start, periods.stop))
    # Bail out before creating a figure, so no empty figure is left open.
    if not period_windows:
        return
    n_periods = len(period_windows)

    fig_height = 6
    fig_width = 10
    rows = 4
    columns = 1
    fig = plt.figure(figsize=(fig_width, fig_height))

    plt.rc('font', family='sans-serif')

    spec = gridspec.GridSpec(rows, columns, figure=fig, hspace=0.7)

    bandwidth_axes = fig.add_subplot(spec[0, 0])
    spike_train_axes = fig.add_subplot(spec[1, 0])
    raster_axes = fig.add_subplot(spec[2, 0])
    slider_axes = fig.add_subplot(8, 1, 8)

    # Slider values are 1-based period numbers, up to n_periods (n_periods + 1
    # would be reachable and index past the end). The bound is kept at least 2
    # so a single period does not produce a zero-width slider.
    slider = Slider(slider_axes,
                    'Period Index\n(Interactive)',
                    1,
                    max(n_periods, 2),
                    valinit=1,
                    valstep=1.0)

    def plot_time(one_based_index):
        # Clamp, because the slider range may extend past the last period (see above).
        index = min(int(np.round(one_based_index - 1)), n_periods - 1)
        start, stop = period_windows[index]
        _plot_hfo_sample(hfo_detection_run, np.float64(start),
                         np.float64(stop), bandwidth_axes, spike_train_axes,
                         raster_axes)
        fig.canvas.draw_idle()

    plot_time(1)
    slider.on_changed(plot_time)

    if should_show_plot(hfo_detection_run.configuration):
        plt.show()
    if should_save_plot(hfo_detection_run.configuration):
        for one_based_index in range(1, n_periods + 1):
            slider.set_val(one_based_index)
            save_or_show_channel_plot(f'hfo_sample_period_{one_based_index}',
                                      hfo_detection_run)
Exemple #9
0
class HistoryPlotter:
    """Browse saved history chunks with a slider and Previous/Next buttons.

    ``saved_filename`` is a bz2-compressed pickle of a sequence of chunks.
    The 'buffer' and 'stats' entries of each chunk are passed to
    ``RealTimePlotter.plot``.
    """
    def __init__(self, saved_filename):
        self.plotter = RealTimePlotter()
        self.plotter.fig.suptitle('History data', fontsize='14', fontweight='bold')
        self.plotter.fig.subplots_adjust(bottom=0.23, hspace=0.5)
        self.index = 0
        self.saved_filename = saved_filename
        self.data = []
        self.prev_axis = plt.axes([0.395, 0.03, 0.1, 0.06])
        self.next_axis = plt.axes([0.505, 0.03, 0.1, 0.06])
        self.btn_next = Button(self.next_axis, 'Next')
        self.btn_next.on_clicked(self.next)
        self.btn_prev = Button(self.prev_axis, 'Previous')
        self.btn_prev.on_clicked(self.prev)
        self.slider_axis = plt.axes([0.25, 0.11, 0.5, 0.03])
        self.read_data()
        self.slider = Slider(self.slider_axis, 'chunk', 0, len(self.data)-1, valinit=0, valfmt='%10.0f')
        self.slider.on_changed(self.update)

    def read_data(self):
        """Load every chunk from ``self.saved_filename`` into ``self.data``."""
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # history files from a trusted source.
        with bz2.BZ2File(self.saved_filename, 'r') as f:
            self.data = pickle.load(f)
            print('Number of chunks', len(self.data))

    def loop(self):
        """Show the current chunk and block in the GUI until interrupted."""
        try:
            self.plotter.show()
            self.plotter.plot(self.data[self.index]['buffer'], self.data[self.index]['stats'])
            plt.show(block=True)
        except KeyboardInterrupt:
            print("Stopping...")

    def update(self, val):
        """Slider callback: plot chunk ``int(val)``."""
        self.index = int(val)
        self.plotter.plot(self.data[self.index]['buffer'], self.data[self.index]['stats'])

    def next(self, event):
        """Step to the next chunk (no-op at the last one)."""
        if self.index < len(self.data) - 1:
            self.index += 1
            # set_val fires update(), which plots the chunk. Plotting here
            # too would draw it twice.
            self.slider.set_val(self.index)

    def prev(self, event):
        """Step to the previous chunk (no-op at the first one)."""
        if self.index > 0:
            self.index -= 1
            self.slider.set_val(self.index)
Exemple #10
0
class ImgView(object):
    """Show a 2D image with sliders for the lower/upper colour-scale limits.

    Parameters
    ----------
    img : numpy.ndarray
        2D image to display.
    fig_id : optional
        Passed to ``plt.figure`` to choose the figure.
    imin, imax : float, optional
        Slider range. They default to the values at about the 0.5 % and
        99.5 % positions of the sorted pixel values.
    """
    def __init__(self, img, fig_id=None, imin=None, imax=None):
        self.fig = plt.figure(fig_id)
        self.ax = self.fig.add_axes([0.1, 0.25, 0.8, 0.7])
        self.fig.subplots_adjust(bottom=0.25)

        # ``facecolor`` replaces the ``axisbg`` keyword, which matplotlib removed.
        axcolor = 'lightgoldenrodyellow'
        self.axlo = self.fig.add_axes([0.15, 0.1, 0.65, 0.03], facecolor=axcolor)
        self.axhi = self.fig.add_axes([0.15, 0.15, 0.65, 0.03], facecolor=axcolor)
        self.axrefr = self.fig.add_axes([0.15, 0.05, 0.10, 0.03], facecolor=axcolor)

        # Image artist and colorbar axes; created by the first set_img call.
        self.imgplt = None
        self.cax = None

        imsort = np.sort(img.flatten())
        n = len(imsort)
        if imin is None:
            imin = imsort[int(n * 0.005)]
        if imax is None:
            imax = imsort[int(n * 0.995)]
        # Initial limits: values at about the 2 % and 98 % positions.
        self.slo = Slider(self.axlo, 'Scale min', imin, imax, valinit=imsort[int(n * 0.02)])
        self.shi = Slider(self.axhi, 'Scale max', imin, imax, valinit=imsort[int(n * 0.98)])
        self.brefr = Button(self.axrefr, 'Refresh')
        self.slo.on_changed(self.update_slider)
        self.shi.on_changed(self.update_slider)
        self.brefr.on_clicked(self.refresh)
        self.set_img(img)

    def update_slider(self, val):
        """Apply the slider limits, keeping the upper limit above the lower one."""
        if self.shi.val <= self.slo.val:
            # set_val re-enters this callback with the adjusted upper limit.
            self.shi.set_val(self.slo.val + 1)
        self.imgplt.set_clim(self.slo.val, self.shi.val)
        self.fig.canvas.draw()

    def set_img(self, img=None):
        """Display ``img``, or redisplay the current image when ``img`` is None.

        The image artist and its colorbar are created only on the first call.
        Later calls update the existing artist, so refreshing does not stack
        images or keep shrinking the main axes with extra colorbar axes.
        """
        if img is not None:
            self.img = img
        height, width = self.img.shape[0], self.img.shape[1]
        if getattr(self, 'imgplt', None) is None:
            self.ax.set_xlim(0, width - 1)
            self.ax.set_ylim(0, height - 1)
            self.imgplt = self.ax.imshow(self.img, vmin=self.slo.val, vmax=self.shi.val,
                                         cmap='hot', origin='lower',
                                         interpolation='nearest')
            divider = make_axes_locatable(self.ax)
            self.cax = divider.append_axes("right", size="5%", pad=0.05)
            self.fig.colorbar(self.imgplt, cax=self.cax)
        else:
            # set_data copies the array, so in-place edits of self.img show up.
            self.imgplt.set_data(self.img)
            # Same extent imshow computes for origin='lower' (shape may change).
            self.imgplt.set_extent((-0.5, width - 0.5, -0.5, height - 0.5))
            self.imgplt.set_clim(self.slo.val, self.shi.val)
            # Re-apply the limits after set_extent, which may adjust them.
            self.ax.set_xlim(0, width - 1)
            self.ax.set_ylim(0, height - 1)
        self.fig.canvas.draw()

    def refresh(self, event=None):
        """Button callback: redraw the current image."""
        self.set_img()
    class Index(object):
        """Slider plus Previous/Next buttons for stepping through ``R``.

        Positions ``0 .. len(wavelengths) - 1`` show the image for that
        wavelength. The extra last position ``len(wavelengths)`` shows the
        "Calibrated Pixels" image. Relies on ``wavelengths``, ``R``,
        ``image``, ``cbar``, ``maximum``, ``np`` and ``plt`` from the
        enclosing scope.
        """
        def __init__(self, ax_slider, ax_prev, ax_next):
            self.ind = 0
            self.num = len(wavelengths)
            self.bnext = Button(ax_next, 'Next')
            self.bnext.on_clicked(self.next)
            self.bprev = Button(ax_prev, 'Previous')
            self.bprev.on_clicked(self.prev)
            self.slider = Slider(ax_slider,
                                 "Energy Resolution: {:.2f} nm".format(wavelengths[0]), 0,
                                 self.num, valinit=0, valfmt='%d')
            # The numeric readout is hidden; the label text describes the
            # current position instead (see update).
            self.slider.valtext.set_visible(False)
            self.slider.label.set_horizontalalignment('center')
            self.slider.on_changed(self.update)

            # Move label and readout to (0.5, -0.5) in slider-axes coordinates.
            self.slider.label.set_position((0.5, -0.5))
            self.slider.valtext.set_position((0.5, -0.5))

        def next(self, event):
            """Advance one position, wrapping from the last to the first."""
            self.slider.set_val((self.ind + 1) % (self.num + 1))

        def prev(self, event):
            """Go back one position, wrapping from the first to the last."""
            self.slider.set_val((self.ind - 1) % (self.num + 1))

        def update(self, i):
            """Show position ``int(i)`` and rescale the colorbar to match."""
            self.ind = int(i)
            image.set_data(R[self.ind])
            if self.ind != len(wavelengths):
                self.slider.label.set_text("Energy Resolution: {:.2f} nm"
                                           .format(wavelengths[self.ind]))
                cbar.set_clim(vmin=0, vmax=maximum)
                cbar_ticks = np.linspace(0., maximum, num=11, endpoint=True)
            else:
                self.slider.label.set_text("Calibrated Pixels")
                cbar.set_clim(vmin=0, vmax=1)
                cbar_ticks = np.linspace(0., 1, num=2)
            cbar.set_ticks(cbar_ticks)
            cbar.draw_all()
            plt.draw()
Exemple #12
0
    def set_val(self, val):
        """
        Move the slider to integer index ``val`` and recolour the page depth.

        Values outside the valid index range are silently ignored.

        Notes
        -----
        valmin/valmax are set on the parent to 0 and len(depths), so the
        valid indices are valmin <= val < valmax. valmax itself would point
        one past the end of the depth array.
        """
        index = int(val)
        if not (self.valmin <= index < self.valmax):
            return
        # Recolour first: updatePageDepthColor runs while self.val still holds
        # the previous value, before Slider.set_val replaces it.
        self.updatePageDepthColor(index)
        Slider.set_val(self, index)
Exemple #13
0
class GUI:
    """Interactive plot of a coin-toss posterior after a chosen number of tosses.

    ``coinToss`` must provide ``tosses`` (the slider maximum), ``doTosses()``,
    ``conditional`` (the x values), and a 2D ``posterior`` whose row ``n`` is
    plotted for ``n`` tosses.

    Controls: the slider or the left/right arrow keys choose the number of
    tosses. The 'Re-toss' button or the 'r' key call ``doTosses`` again.
    """

    def __init__(self, coinToss):
        self._coinToss = coinToss
        maxTosses = coinToss.tosses

        axcolor = 'lightgoldenrodyellow'

        self.figure = pyplot.figure()
        self.mainAxis = pyplot.axes([0.05, 0.2, 0.9, 0.75])

        # Slider for number of tosses
        tossAxis = pyplot.axes([0.1, 0.05, 0.8, 0.05])
        self.tossSlider = Slider(tossAxis, 'Tosses', 0., 1.*maxTosses, valinit=0, valfmt=u'%d')
        self.tossSlider.on_changed(lambda x: self.draw(x))

        # Reset button
        resetAxis = pyplot.axes([0.8, 0.85, 0.1, 0.05])
        self.resetButton = Button(resetAxis, 'Re-toss', color=axcolor, hovercolor='0.975')
        self.resetButton.on_clicked(self.retoss)

        # Key press events
        self.figure.canvas.mpl_connect('key_press_event', lambda x: self.press(x))

        self._coinToss.doTosses()

    def retoss(self, event):
        """Redo the tosses and redraw at the current slider position."""
        self._coinToss.doTosses()
        self.draw(self.tossSlider.val)

    def press(self, event):
        """Arrow keys step the slider by one toss; 'r' re-tosses."""
        slider = self.tossSlider
        if event.key == u'left' and slider.val > slider.valmin:
            # Clamp: a fractional position (from dragging) could otherwise
            # step below the slider range.
            slider.set_val(max(slider.val - 1, slider.valmin))
        if event.key == u'right' and slider.val < slider.valmax:
            slider.set_val(min(slider.val + 1, slider.valmax))
        if event.key == u'r':
            self.retoss(event)

    def draw(self, x = 0):
        """Plot the posterior after ``int(x)`` tosses and mark its maximum."""
        c = self._coinToss
        row = c.posterior[int(x), :]
        # (index, value) of the largest posterior entry; first one on ties.
        m = max(enumerate(row), key=operator.itemgetter(1))
        self.mainAxis.clear()
        self.mainAxis.plot(c.conditional, row, lw=2, color='red')
        self.mainAxis.vlines(c.conditional[m[0]], 0, c.posterior.max())
        self.figure.canvas.draw()
    class IndexTracker(object):
        """Step through the first axis of a 3D array with arrow keys or a slider.

        Displays ``X[ind, :, :].T``. ``press`` moves ``step`` slices
        forward ('right') or backward ('left'), wrapping around. The slider
        selects a slice directly.
        """
        def __init__(self, ax, X, step=41):
            self.ax = ax
            ax.figure.subplots_adjust(left=0.25, bottom=0.25)
            # NOTE(review): the title mentions the scroll wheel, but this class
            # only handles key presses (press) and the slider.
            ax.set_title('use scroll wheel to navigate images')

            self.step = step
            self.X = X
            self.slices, _rows, _cols = X.shape
            self.ind = 0

            # Show the transposed slice from the start, so the image extent
            # matches the data update() puts into it.
            self.im = ax.imshow(self.X[self.ind, :, :].T,
                                vmin=np.min(X),
                                vmax=np.max(X))

            # Use the tracked axes' own figure rather than a global ``fig``.
            slider_ax = ax.figure.add_axes([0.25, 0.1, 0.65, 0.03])
            # Valid slice indices are 0 .. slices - 1.
            self.slider = Slider(slider_ax,
                                 'Axis %i index' % self.slices,
                                 0,
                                 self.slices - 1,
                                 valinit=self.ind,
                                 valfmt='%i')
            self.slider.on_changed(self.update_slider)

            self.update()

        def press(self, event):
            """Move ``step`` slices forward/backward with the right/left keys."""
            if event.key == 'right':
                new_ind = (self.ind + self.step) % self.slices
            elif event.key == 'left':
                new_ind = (self.ind - self.step) % self.slices
            else:
                return
            self.ind = new_ind
            # set_val fires update_slider(), which redraws; no extra update().
            self.slider.set_val(new_ind)

        def update_slider(self, event):
            """Slider callback: show slice ``int(slider.val)``."""
            self.ind = int(self.slider.val)
            self.update()

        def update(self):
            """Redraw the current slice and its label."""
            self.im.set_data(self.X[self.ind, :, :].T)
            self.ax.set_ylabel('slice %s' % self.ind)
            self.im.axes.figure.canvas.draw()
    def plot_data(self):
        """Draw the data with ``draw_vals`` and attach a 'Beta value' slider.

        The slider spans 4..50 in steps of 0.1 and starts at 10. Moving it
        calls ``update_beta``. Finishes with ``plt.show()``.
        """
        import matplotlib.pyplot as plt
        from matplotlib.widgets import Slider

        plt.figure()
        main_axes = plt.axes([0.085, 0.2, 0.9, 0.75])
        slider_axes = plt.axes([0.21, 0.02, 0.7, 0.03])

        self.draw_vals(main_axes)

        beta_slider = Slider(slider_axes,
                             label='Beta value (1e-8)',
                             valmin=4,
                             valmax=50,
                             valstep=0.1)
        # The start value is set before the callback is registered, so
        # update_beta is not invoked for it.
        beta_slider.set_val(10)
        beta_slider.on_changed(self.update_beta)
        plt.show()
Exemple #16
0
def animate_plot_to_pictures(min_speed, max_speed, pictures):
    """Save ``pictures`` frames of the acceleration colour wheel as PNG files.

    Frame ``i`` is saved as ``pic_%04d.png``. Its speed is interpolated
    linearly from ``min_speed`` (first frame) to ``max_speed`` (last frame),
    and the figure's speed slider is moved to that speed before saving.
    """
    axcolor = 'lightgoldenrodyellow'
    axspeed = plt.axes([0.125, 0.05, 0.65, 0.03], facecolor=axcolor)
    sspeed = Slider(axspeed, 'Speed', 0, max_speed, valinit=0, valstep=1)

    for i in range(pictures):
        # Interpolate within [min_speed, max_speed]. Adding min_speed to
        # i/(pictures-1)*max_speed would overshoot max_speed (and the slider
        # range) whenever min_speed > 0, and a single picture would divide by zero.
        fraction = float(i) / (pictures - 1) if pictures > 1 else 0.0
        speed = min_speed + fraction * (max_speed - min_speed)
        accels, rads = strafevis.strafe_stats.get_stats(720, strafevis.strafe_stats.StatType.ACCEL, speed=speed)
        display_axes = plt.subplot(1, 1, 1, polar=True)
        norm = mpl.colors.Normalize(0.0, 2 * np.pi)
        cmap = AngleMap(accels)

        cb = mpl.colorbar.ColorbarBase(display_axes, cmap=cmap,
                                       norm=norm,
                                       orientation='horizontal')

        # aesthetics - get rid of border and axis labels
        cb.outline.set_visible(False)
        display_axes.set_axis_off()
        sspeed.set_val(speed)
        plt.savefig('pic_%04d.png' % i)
class Visualizer(object):
    """Play back a snapshot's scenes side by side with a time slider.

    One axes and one ``Scene`` are created per snapshot key. Clicking the
    'Prefer <key>' button above an axes, or pressing a key whose name matches
    a snapshot key (case-insensitively), stores that key in ``self.choice``
    and closes the figure. Escape closes the figure, 'r' rewinds to 0, and
    up/down move the time by 0.2 within [0, T]. If ``save_on_close`` is
    given, the figure is saved there when it is closed.
    """
    def __init__(self, snapshot, save_on_close=None):
        self.snapshot = snapshot
        self.save_on_close = save_on_close
        # Duration taken from the first entry of snapshot.human.
        # next(iter(...)) works whether values() is a list or a dict view
        # (dict views cannot be indexed).
        self.T = next(iter(self.snapshot.human.values())).T
        self.choice = None

    @property
    def t(self):
        """Current time; setting it propagates to every scene."""
        return self._t

    @t.setter
    def t(self, value):
        self._t = value
        for scene in self.scenes:
            scene.t = self.t

    def select(self, key):
        """Record ``key`` as the preferred choice and close the figure."""
        self.choice = key
        plt.close(self.fig)

    def close(self, event):
        """Close-event handler: save the figure if ``save_on_close`` is set."""
        if self.save_on_close is not None:
            self.fig.savefig(self.save_on_close)

    def run(self):
        """Build the figure, time slider and preference buttons, then show it."""
        keys = list(self.snapshot.keys())
        # squeeze=False always returns a 2D array of axes, so a snapshot with
        # a single key still yields an iterable row of axes.
        fig, axes = plt.subplots(1, len(keys), sharex=True, sharey=True,
                                 figsize=(13, 7), squeeze=False)
        self.fig, self.ax = fig, axes[0]
        self.fig.canvas.mpl_connect('key_press_event', self.key_press)
        self.fig.canvas.mpl_connect('close_event', self.close)
        self.scenes = [Scene(ax, self.snapshot.view(key)) for ax, key in zip(self.ax, keys)]

        self.fig.subplots_adjust(bottom=0.15, top=0.85)

        box = self.fig.add_axes([0.15, 0.05, 0.7, 0.05])
        self.slider = Slider(box, 'Time', 0., self.T, valinit=0.)
        self.t = 0.
        def update_t(t):
            self.t = t
        self.slider.on_changed(update_t)

        def click(key):
            # Factory so each button's callback captures its own key.
            def f(event):
                self.select(key)
            return f
        self.buttons = []
        for ax, key in zip(self.ax, keys):
            # Axes.figbox was removed from matplotlib; the original (layout)
            # position provides the same x0 / y1 / width.
            box = ax.get_position(original=True)
            box = self.fig.add_axes([box.x0, box.y1+0.05, box.width, 0.05])
            self.buttons.append(Button(box, 'Prefer {}'.format(key)))
            self.buttons[-1].on_clicked(click(key))

        plt.show()

    def key_press(self, event):
        """Keyboard shortcuts; see the class docstring."""
        if event.key is None:
            # Key events without a key name have nothing to match.
            return
        if event.key=='escape':
            plt.close(self.fig)
        elif event.key=='r':
            self.slider.set_val(0.)
        elif event.key=='up':
            self.slider.set_val(min(max(self.t+0.2, 0), self.T))
        elif event.key=='down':
            self.slider.set_val(min(max(self.t-0.2, 0), self.T))
        elif event.key.lower() in [s.lower() for s in self.snapshot.keys()]:
            for key in self.snapshot.keys():
                if event.key.lower()==key.lower():
                    self.select(key)
Exemple #18
0
class slicer(object):
    """Show 2D slices of a volume along one axis, chosen with a slider or +/- keys.

    Parameters
    ----------
    data : numpy.ndarray
        Volume to slice (each ``data[i]`` along ``axis`` is drawn as an image).
    axis : int
        Axis to slice along: 0, 1 or 2 (labelled X, Y, Z).
    init_slice : int
        Index of the slice shown first.
    """
    def __init__(self, data, axis, init_slice):
        self.axis = axis
        self.axis_label = {0: 'X', 1: 'Y', 2: 'Z'}[axis]

        # Move the slicing axis to the front so self.data[i] is slice i.
        self.data = np.swapaxes(data, self.axis, 0)
        self.slice_index = init_slice
        self.slice = self.data[self.slice_index]
        self.max_index = self.data.shape[0]

        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        # Keep the image artist so later updates replace its data instead of
        # stacking a new image on the axes for every slice change.
        self.im = self.ax.imshow(self.slice.T, interpolation='none', origin='lower')
        self.along_axis = plt.axes([0.2, 0.1, 0.65, 0.03])
        self.slab = Slider(self.along_axis, '%s_Slab' % self.axis_label, 0, self.max_index, valinit=self.slice_index, valfmt='%i')
        self.slab.on_changed(self.update_figure)
        self.fig.canvas.mpl_connect('key_press_event', self.update_slice_index)
        plt.show()

    def draw(self):
        # NOTE(review): placeholder code; `your_function`, `self.values` and
        # `pylab` are not defined in the visible code.
        im = your_function(self.values)
        pylab.show()
        self.ax.imshow(im)

    def update_slice_index(self, event):
        """Step one slice forward ('+') or backward ('-'), wrapping at both ends."""
        if event.key == '+':
            # max_index is wrapped to 0 by the modulo in update_figure.
            self.slice_index += 1
        elif event.key == '-':
            self.slice_index -= 1
            if self.slice_index < 0:
                # Wrap to the last valid slice (max_index would map back to 0).
                self.slice_index = self.max_index - 1
        self.slab.set_val(self.slice_index)

    def update_figure(self, event=None):
        """Slider callback: show slice ``slab.val`` (modulo the slice count)."""
        self.slice_index = int(self.slab.val % self.max_index)
        self.slice = self.data[self.slice_index]
        self.im.set_data(self.slice.T)
        self.fig.canvas.draw()
Exemple #19
0
class InteractivePlot(object):
    """Score plot with sliders controlling the visible time window.

    Two sliders below the plot set the start and the width of
    ``score_plot.current``; the left/right arrow keys shift the window by
    half its width.
    """

    def __init__(self, record_wins, sliding_wins, example_labels,
                 similarities, is_test=None, plot_rc=None):
        self.score_plot = ScorePlot(record_wins, sliding_wins, example_labels,
                                    similarities, is_test=is_test,
                                    plot_rc=plot_rc)
        self.fig, self.main_ax = self.score_plot.fig, self.score_plot.main_ax
        # leave room at the bottom of the figure for the two sliders
        plt.subplots_adjust(bottom=0.2)
        self.score_plot.draw()
        max_time = self.score_plot.records.absolute_end
        ax_s = plt.axes([0.15, 0.1, 0.75, 0.03])
        ax_w = plt.axes([0.15, 0.05, 0.75, 0.03])
        self.slider_start = Slider(ax_s, 'Start', 0., max_time, valinit=0.)
        self.slider_start.on_changed(self.update)
        self.slider_width = Slider(ax_w, 'Width', 0., 30., valinit=15.)
        self.slider_width.on_changed(self.update)
        self.fig.canvas.mpl_connect('key_press_event', self.on_key)

    def update(self, val):
        """Apply the current slider values to the displayed window and redraw.

        Args:
            val: Value passed by the slider callback; unused, both sliders
                are read directly.
        """
        self.score_plot.current.absolute_start = self.slider_start.val
        width = self.slider_width.val
        self.score_plot.current.absolute_end = (
                self.score_plot.current.absolute_start + width)
        self.score_plot.draw()

    def on_key(self, event):
        """Shift the window by half its width with the left/right arrow keys.

        The new start is clamped to the start slider's range. Setting the
        slider value runs its ``on_changed`` callback (:meth:`update`), which
        redraws the plot, so no second explicit redraw is done here.

        Args:
            event: Matplotlib key press event; other keys are ignored.
        """
        direction = {'right': 1, 'left': -1}.get(event.key)
        if direction is None:
            return
        start = self.slider_start
        new_start = start.val + direction * .5 * self.slider_width.val
        # Slider.set_val does not clamp, so keep the start inside the range
        new_start = min(max(new_start, start.valmin), start.valmax)
        start.set_val(new_start)
Exemple #20
0
    def CreateDisplay(self):
        """Build the operation selector and the three parameter sliders,
        then show the figure (``plt.show()`` is called at the end).

        Each slider runs from 0 to ``len(values) - 1`` of the list attribute
        that shares its label (``AtomicSize``, ``Capacity``, ``GridSize``).
        Its ``set*_slider`` handler receives the slider as first argument.
        """
        operation_ax = plt.axes([0.025, 0.8, 0.15, 0.15])
        operation_radio = RadioButtons(operation_ax, ("Search", "Insert"),
                                       active=0)
        operation_radio.on_clicked(self.OnOperationTypeSelect)
        operation_radio.set_active(0)

        # (label, which is also the attribute name; bottom of the slider
        #  axes; value format; name of the on_changed handler)
        slider_specs = (
            ('AtomicSize', 0.20, "%1.2f", 'setAS_slider'),
            ('Capacity', 0.15, "%i", 'setC_slider'),
            ('GridSize', 0.10, "%i", 'setGS_slider'),
        )
        slider_axes = [plt.axes([0.25, bottom, 0.65, 0.03])
                       for _, bottom, _, _ in slider_specs]

        # keep every slider referenced until plt.show() returns so none of
        # them is garbage-collected while the window is open
        sliders = []
        for ax, (label, _, fmt, handler_name) in zip(slider_axes,
                                                     slider_specs):
            slider = Slider(ax,
                            label,
                            0,
                            len(getattr(self, label)) - 1,
                            valinit=0,
                            valfmt=fmt)
            slider.on_changed(partial(getattr(self, handler_name), slider))
            slider.set_val(0.0)
            sliders.append(slider)
        plt.show()
Exemple #21
0
class AtlasEditor(plot_support.ImageSyncMixin):
    """Graphical interface to view an atlas in multiple orthogonal
    dimensions and edit atlas labels.

    :attr:`plot_eds` are dictionaries of keys specified by one of
    :const:`magmap.config.PLANE` plane orientations to Plot Editors.
    ``plot_eds`` is not created in this class; presumably it is provided by
    ``plot_support.ImageSyncMixin`` -- TODO confirm.

    Attributes:
        image5d: Numpy image array in t,z,y,x,[c] format.
        labels_img: Numpy image array in z,y,x format.
        channel: Channel of the image to display.
        offset: Index of plane at which to start viewing in x,y,z (user)
            order.
        fn_close_listener: Handle figure close events; called with the
            close event and this ``AtlasEditor``.
        borders_img: Numpy image array in z,y,x,[c] format to show label
            borders, such as that generated during label smoothing.
            Defaults to None. If this image has a different number of
            labels than that of ``labels_img``, a new colormap will
            be generated.
        fn_show_label_3d: Function to call to show a label in a
            3D viewer. Defaults to None.
        title (str): Window title; defaults to None.
        fn_refresh_atlas_eds (func): Callback for refreshing other
            Atlas Editors to synchronize them; called with this
            ``AtlasEditor`` as its only argument. Defaults to None.
        fig: Matplotlib figure to draw into; if None, a new figure is
            created in :meth:`show_atlas`. Defaults to None.
        alpha_slider: Matplotlib alpha slider control.
        alpha_reset_btn: Matplotlib button for resetting alpha transparency.
        alpha_last: Float specifying the previous alpha value.
        interp_planes: Current :class:`InterpolatePlanes` object.
        interp_btn: Matplotlib button to initiate plane interpolation.
        save_btn: Matplotlib button to save the atlas.
        edit_btn: Matplotlib button to toggle editing mode.
        color_picker_box: Matplotlib text box whose value is passed to all
            plot editors as ``intensity_spec``.
        fn_status_bar (func): Function to call during status bar updates
            in :class:`pixel_display.PixelDisplay`; defaults to None.
        fn_update_coords (func): Handler for coordinate updates, which
            takes coordinates in z-plane orientation; defaults to None.
    """

    # edit button labels, passed to ``toggle_btn`` when edit mode changes
    _EDIT_BTN_LBLS = ("Edit", "Editing")

    def __init__(self,
                 image5d,
                 labels_img,
                 channel,
                 offset,
                 fn_close_listener,
                 borders_img=None,
                 fn_show_label_3d=None,
                 title=None,
                 fn_refresh_atlas_eds=None,
                 fig=None,
                 fn_status_bar=None):
        """Store the images, callbacks, and settings for the editor.

        No figure content is created here; call :meth:`show_atlas` to build
        the views and controls. See the class docstring for a description
        of each argument.
        """
        super().__init__()
        self.image5d = image5d
        self.labels_img = labels_img
        self.channel = channel
        self.offset = offset
        self.fn_close_listener = fn_close_listener
        self.borders_img = borders_img
        self.fn_show_label_3d = fn_show_label_3d
        self.title = title
        self.fn_refresh_atlas_eds = fn_refresh_atlas_eds
        self.fig = fig
        self.fn_status_bar = fn_status_bar

        # controls and related state; the widgets are created in show_atlas
        self.alpha_slider = None
        self.alpha_reset_btn = None
        self.alpha_last = None
        self.interp_planes = None
        self.interp_btn = None
        self.save_btn = None
        self.edit_btn = None
        self.color_picker_box = None
        # not a constructor argument; assign after construction if needed
        self.fn_update_coords = None

        self._labels_img_sitk = None  # for saving labels image

    def show_atlas(self):
        """Set up the atlas display with multiple orthogonal views.

        Builds a 2x2 grid of plane viewers (the first plane spans the left
        column) and a control row. The control row holds the opacity slider,
        the Reset, Fill Label, Save and Edit buttons, and the color picker
        box. The method then connects all event handlers and shows the
        planes at :attr:`offset`.
        """
        # set up the figure
        if self.fig is None:
            # NOTE(review): ``figure`` is imported elsewhere; if it is
            # ``matplotlib.figure``, the first positional parameter of
            # ``Figure`` is ``figsize``, not a title -- confirm.
            fig = figure.Figure(self.title)
            self.fig = fig
        else:
            fig = self.fig
        fig.clear()
        # top row: plane viewers; thin bottom row: editor controls
        gs = gridspec.GridSpec(2,
                               1,
                               wspace=0.1,
                               hspace=0.1,
                               height_ratios=(20, 1),
                               figure=fig,
                               left=0.06,
                               right=0.94,
                               bottom=0.02,
                               top=0.98)
        gs_viewers = gridspec.GridSpecFromSubplotSpec(2,
                                                      2,
                                                      subplot_spec=gs[0, 0])

        # set up a colormap for the borders image if present
        cmap_borders = colormaps.get_borders_colormap(self.borders_img,
                                                      self.labels_img,
                                                      config.cmap_labels)
        # offset is stored in x,y,z order; update_coords expects z,y,x
        coord = list(self.offset[::-1])

        # editor controls, split into a slider sub-spec to allow greater
        # spacing for labels on either side and a separate sub-spec for
        # buttons and other fields
        gs_controls = gridspec.GridSpecFromSubplotSpec(1,
                                                       2,
                                                       subplot_spec=gs[1, 0],
                                                       width_ratios=(1, 1),
                                                       wspace=0.15)
        self.alpha_slider = Slider(
            fig.add_subplot(gs_controls[0, 0]),
            "Opacity",
            0.0,
            1.0,
            valinit=plot_editor.PlotEditor.ALPHA_DEFAULT)
        gs_controls_btns = gridspec.GridSpecFromSubplotSpec(
            1, 5, subplot_spec=gs_controls[0, 1], wspace=0.1)
        self.alpha_reset_btn = Button(fig.add_subplot(gs_controls_btns[0, 0]),
                                      "Reset")
        self.interp_btn = Button(fig.add_subplot(gs_controls_btns[0, 1]),
                                 "Fill Label")
        self.interp_planes = InterpolatePlanes(self.interp_btn)
        self.interp_planes.update_btn()
        self.save_btn = Button(fig.add_subplot(gs_controls_btns[0, 2]), "Save")
        self.edit_btn = Button(fig.add_subplot(gs_controls_btns[0, 3]), "Edit")
        self.color_picker_box = TextBox(
            fig.add_subplot(gs_controls_btns[0, 4]), None)

        # adjust button colors based on theme and enabled status; note
        # that colors do not appear to refresh until fig mouseover
        for btn in (self.alpha_reset_btn, self.edit_btn):
            enable_btn(btn)
        # nothing has been edited yet, so start with Save disabled
        enable_btn(self.save_btn, False)
        enable_btn(self.color_picker_box, color=config.widget_color + 0.1)

        def setup_plot_ed(axis, gs_spec):
            """Create a plane slider and a PlotEditor for one plane.

            Args:
                axis (int): Index into :const:`magmap.config.PLANE`.
                gs_spec: Subplot spec in which to place the slider and plot.

            Returns:
                The new :class:`magmap.plot_editor.PlotEditor`.
            """
            # set up a PlotEditor for the given axis

            # subplot grid, with larger height preference for plot for
            # each increased row to make sliders of approx equal size and
            # align top borders of top images
            # NOTE(review): ``SubplotSpec.get_rows_columns`` is deprecated
            # or removed in newer Matplotlib releases; confirm it exists in
            # the version in use.
            rows_cols = gs_spec.get_rows_columns()
            extra_rows = rows_cols[3] - rows_cols[2]
            gs_plot = gridspec.GridSpecFromSubplotSpec(
                2,
                1,
                subplot_spec=gs_spec,
                height_ratios=(1, 10 + 14 * extra_rows),
                hspace=0.1 / (extra_rows * 1.4 + 1))

            # transform arrays to the given orthogonal direction
            ax = fig.add_subplot(gs_plot[1, 0])
            plot_support.hide_axes(ax)
            plane = config.PLANE[axis]
            arrs_3d, aspect, origin, scaling = \
                plot_support.setup_images_for_plane(
                    plane,
                    (self.image5d[0], self.labels_img, self.borders_img))
            img3d_tr, labels_img_tr, borders_img_tr = arrs_3d

            # slider through image planes
            ax_scroll = fig.add_subplot(gs_plot[0, 0])
            plane_slider = Slider(ax_scroll,
                                  plot_support.get_plane_axis(plane),
                                  0,
                                  len(img3d_tr) - 1,
                                  valfmt="%d",
                                  valinit=0,
                                  valstep=1)

            # plot editor; max_sizes is a closure variable assigned in
            # show_atlas just before this helper is first called
            max_size = max_sizes[axis] if max_sizes else None
            plot_ed = plot_editor.PlotEditor(
                ax,
                img3d_tr,
                labels_img_tr,
                config.cmap_labels,
                plane,
                aspect,
                origin,
                self.update_coords,
                self.refresh_images,
                scaling,
                plane_slider,
                img3d_borders=borders_img_tr,
                cmap_borders=cmap_borders,
                fn_show_label_3d=self.fn_show_label_3d,
                interp_planes=self.interp_planes,
                fn_update_intensity=self.update_color_picker,
                max_size=max_size,
                fn_status_bar=self.fn_status_bar)
            return plot_ed

        # setup plot editors for all 3 orthogonal directions; the first
        # plane spans both rows of the left column, the other two share the
        # right column
        max_sizes = plot_support.get_downsample_max_sizes()
        for i, gs_viewer in enumerate(
            (gs_viewers[:2, 0], gs_viewers[0, 1], gs_viewers[1, 1])):
            self.plot_eds[config.PLANE[i]] = setup_plot_ed(i, gs_viewer)
        self.set_show_crosslines(True)

        # attach listeners
        fig.canvas.mpl_connect("scroll_event", self.scroll_overview)
        fig.canvas.mpl_connect("key_press_event", self.on_key_press)
        fig.canvas.mpl_connect("close_event", self._close)
        fig.canvas.mpl_connect("axes_leave_event", self.axes_exit)

        self.alpha_slider.on_changed(self.alpha_update)
        self.alpha_reset_btn.on_clicked(self.alpha_reset)
        self.interp_btn.on_clicked(self.interpolate)
        self.save_btn.on_clicked(self.save_atlas)
        self.edit_btn.on_clicked(self.toggle_edit_mode)
        self.color_picker_box.on_text_change(self.color_picker_changed)

        # initialize and show planes in all plot editors
        # (_max_intens_proj is not set in this class; presumably inherited)
        if self._max_intens_proj is not None:
            self.update_max_intens_proj(self._max_intens_proj)
        self.update_coords(coord, config.PLANE[0])

        plt.ion()  # avoid the need for draw calls

    def _close(self, evt):
        """Handle figure close events by calling :attr:`fn_close_listener`
        with this object.

        Args:
            evt (:obj:`matplotlib.backend_bases.CloseEvent`): Close event.

        """
        self.fn_close_listener(evt, self)

    def on_key_press(self, event):
        """Respond to key press events.

        Keys:
            ``a``: toggle the opacity slider between 0 and the saved value.
            ``A``: halve the opacity slider value.
            ``up``/``down``: scroll planes in all plot editors.
            ``w``: toggle editing mode.
            ``ctrl+s``/``cmd+s``: save the figure to :meth:`get_save_path`.

        Args:
            event: Matplotlib key press event.
        """
        if event.key == "a":
            # toggle between current and 0 opacity
            if self.alpha_slider.val == 0:
                # return to saved alpha if available and reset
                if self.alpha_last is not None:
                    self.alpha_slider.set_val(self.alpha_last)
                self.alpha_last = None
            else:
                # make translucent, saving alpha if not already saved
                # during a halve-opacity event
                if self.alpha_last is None:
                    self.alpha_last = self.alpha_slider.val
                self.alpha_slider.set_val(0)
        elif event.key == "A":
            # halve opacity, only saving alpha on first halving to allow
            # further halving or manual movements while still returning to
            # originally saved alpha
            if self.alpha_last is None:
                self.alpha_last = self.alpha_slider.val
            self.alpha_slider.set_val(self.alpha_slider.val / 2)
        elif event.key == "up" or event.key == "down":
            # up/down arrow for scrolling planes
            self.scroll_overview(event)
        elif event.key == "w":
            # shortcut to toggle editing mode
            self.toggle_edit_mode(event)
        elif event.key == "ctrl+s" or event.key == "cmd+s":
            # support default save shortcuts on multiple platforms;
            # ctrl-s will bring up save dialog from fig, but cmd/win-S
            # will bypass
            self.save_fig(self.get_save_path())

    def update_coords(self, coord, plane_src=config.PLANE[0]):
        """Update all plot editors with given coordinates.

        Also stores the first plane's coordinates, reversed, in
        :attr:`offset` and forwards them to :attr:`fn_update_coords` if set.

        Args:
            coord: Coordinate at which to center images, in z,y,x order.
            plane_src: One of :const:`magmap.config.PLANE` to specify the
                orientation from which the coordinates were given; defaults
                to the first element of :const:`magmap.config.PLANE`.
        """
        # undo the source plane's transposition, then re-transpose per plane
        coord_rev = libmag.transpose_1d_rev(list(coord), plane_src)
        for i, plane in enumerate(config.PLANE):
            coord_transposed = libmag.transpose_1d(list(coord_rev), plane)
            if i == 0:
                self.offset = coord_transposed[::-1]
                if self.fn_update_coords:
                    # update offset based on xy plane, without centering
                    # planes are centered on the offset as-is
                    self.fn_update_coords(coord_transposed, False)
            self.plot_eds[plane].update_coord(coord_transposed)

    def view_subimg(self, offset, shape):
        """Zoom all Plot Editors to the given sub-image.

        Args:
            offset: Sub-image coordinates in ``z,y,x`` order.
            shape: Sub-image shape in ``z,y,x`` order.

        """
        for i, plane in enumerate(config.PLANE):
            offset_tr = libmag.transpose_1d(list(offset), plane)
            shape_tr = libmag.transpose_1d(list(shape), plane)
            # drop the first (plane) dimension; editors take in-plane values
            self.plot_eds[plane].view_subimg(offset_tr[1:], shape_tr[1:])
        self.fig.canvas.draw_idle()

    def refresh_images(self, plot_ed=None, update_atlas_eds=False):
        """Refresh images in a plot editor, such as after editing one
        editor and updating the displayed image in the other editors.

        Args:
            plot_ed (:obj:`magmap.plot_editor.PlotEditor`): Editor that
                does not need updating, typically the editor that originally
                changed. Defaults to None.
            update_atlas_eds (bool): True to update other ``AtlasEditor``s;
                defaults to False.
        """
        for key in self.plot_eds:
            ed = self.plot_eds[key]
            if ed != plot_ed: ed.refresh_img3d_labels()
            if ed.edited:
                # display save button as enabled if any editor has been edited
                enable_btn(self.save_btn)
        if update_atlas_eds and self.fn_refresh_atlas_eds is not None:
            # callback to synchronize other Atlas Editors
            self.fn_refresh_atlas_eds(self)

    def scroll_overview(self, event):
        """Scroll images and crosshairs in all plot editors

        Args:
            event: Scroll event (or the up/down key press event forwarded
                from :meth:`on_key_press`).
        """
        for key in self.plot_eds:
            self.plot_eds[key].scroll_overview(event)

    def alpha_update(self, event):
        """Update the alpha transparency in all plot editors.

        Args:
            event: Slider event.
        """
        for key in self.plot_eds:
            self.plot_eds[key].alpha_updater(event)

    def alpha_reset(self, event):
        """Reset the alpha transparency in all plot editors.

        Resets :attr:`alpha_slider`; the slider's ``on_changed`` handler
        (:meth:`alpha_update`) then propagates the value to the editors.

        Args:
            event: Button event, currently ignored.
        """
        self.alpha_slider.reset()

    def axes_exit(self, event):
        """Trigger axes exit for all plot editors.

        Args:
            event: Axes exit event.
        """
        for key in self.plot_eds:
            self.plot_eds[key].on_axes_exit(event)

    def interpolate(self, event):
        """Interpolate planes using :attr:`interp_planes`.

        Errors raised as ``ValueError`` are printed rather than propagated.

        Args:
            event: Button event, currently ignored.
        """
        try:
            self.interp_planes.interpolate(self.labels_img)
            # flag Plot Editors as edited so labels can be saved
            for ed in self.plot_eds.values():
                ed.edited = True
            self.refresh_images(None, True)
        except ValueError as e:
            print(e)

    def save_atlas(self, event):
        """Save atlas labels using the registered image suffix given by
        :attr:`config.reg_suffixes[config.RegSuffixes.ANNOTATION]`.

        Does nothing unless at least one plot editor has been edited.

        Args:
            event: Button event, currently not used.

        """
        # only save if at least one editor has been edited
        if not any([ed.edited for ed in self.plot_eds.values()]): return

        # save to the labels reg suffix; use sitk Image if loaded and store
        # any Image loaded during saving
        reg_name = config.reg_suffixes[config.RegSuffixes.ANNOTATION]
        if self._labels_img_sitk is None:
            self._labels_img_sitk = config.labels_img_sitk
        self._labels_img_sitk = sitk_io.write_registered_image(
            self.labels_img,
            config.filename,
            reg_name,
            self._labels_img_sitk,
            overwrite=True)

        # reset edited flag in all editors and show save button as disabled
        for ed in self.plot_eds.values():
            ed.edited = False
        enable_btn(self.save_btn, False)
        print("Saved labels image at {}".format(datetime.datetime.now()))

    def get_save_path(self):
        """Get figure save path based on filename, ROI, and overview plane
         shown.

        Returns:
            str: Figure save path.

        """
        ext = config.savefig if config.savefig else config.DEFAULT_SAVEFIG
        # NOTE(review): ``title`` defaults to None, and
        # ``os.path.basename(None)`` raises TypeError; confirm a title is
        # always set before saving.
        return "{}.{}".format(
            naming.get_roi_path(os.path.basename(self.title), self.offset),
            ext)

    def toggle_edit_mode(self, event):
        """Toggle editing mode, determining the current state from the
        first :class:`magmap.plot_editor.PlotEditor` and switching to the
        opposite value for all plot editors.

        Turning editing off also clears the color picker text box.

        Args:
            event: Button event, currently not used.
        """
        edit_mode = False
        for i, ed in enumerate(self.plot_eds.values()):
            if i == 0:
                # change edit mode based on current mode in first plot editor
                edit_mode = not ed.edit_mode
                toggle_btn(self.edit_btn, edit_mode, text=self._EDIT_BTN_LBLS)
            ed.edit_mode = edit_mode
        if not edit_mode:
            # reset the color picker text box when turning off editing
            self.color_picker_box.set_val("")

    def update_color_picker(self, val):
        """Update the color picker :class:`TextBox` with the given value.

        Args:
            val (str): Color value. If None, only :meth:`color_picker_changed`
                will be triggered.
        """
        if val is None:
            # updated picked color directly
            self.color_picker_changed(val)
        else:
            # update text box, which triggers color_picker_changed
            self.color_picker_box.set_val(val)

    def color_picker_changed(self, text):
        """Respond to color picker :class:`TextBox` changes by updating
        the specified intensity value in all plot editors.

        Args:
            text (str): String of text box value. Converted to an int if
                non-empty; non-numeric text is ignored. Empty text (or None)
                is passed to the editors unchanged.
        """
        intensity = text
        if text:
            if not libmag.is_number(intensity): return
            # NOTE(review): if libmag.is_number accepts floats, int() raises
            # ValueError for text such as "1.5" -- confirm.
            intensity = int(intensity)
        print("updating specified color to", intensity)
        for i, ed in enumerate(self.plot_eds.values()):
            ed.intensity_spec = intensity
Exemple #22
0
class VideoViewer(object):
    """
    A matplotlib-based video viewer.

    Parameters
    ----------
    video : list-like, iterator
        A list of a tuple of 2D arrays or a generator of a tuple of 2D arrays.
        If an iterator is provided, you must set 'count' as well.
    count: int
        Length of the video. When this is set it displays only first 'count' frames of the video.
    id : int, optional
        For multi-frame data specifies camera index.
    norm_func : callable
        Normalization function that takes a single argument (array) and returns
        a single element (array). Can be used to apply custom normalization
        function to the image before it is shown.
    title : str, optional
        Plot title.
    kw : options, optional
        Extra arguments passed directly to imshow function

    Raises
    ------
    TypeError
        If ``count`` is not given and ``video`` has no ``len()``.

    Notes
    -----
    Iterator-backed videos can only be browsed forwards: frames are consumed
    as the slider advances, and moving the slider backwards leaves the
    displayed frame unchanged.

    Examples
    --------

    >>> from cddm.viewer import VideoViewer
    >>> video = (np.random.randn(256,256) for i in range(256))
    >>> vg = VideoViewer(video, 256, title = "iterator example") #must set count, because video has no __len__

    #>>> vg.show()

    >>> video = [np.random.randn(256,256) for i in range(256)]
    >>> vl = VideoViewer(video, title = "list example")

    #>>> vl.show()
    """
    def __init__(self,
                 video,
                 count=None,
                 id=0,
                 norm_func=lambda x: x.real,
                 title="",
                 **kw):

        if count is None:
            try:
                count = len(video)
            except TypeError:
                # TypeError subclasses Exception, so existing handlers still
                # catch it; the failed len() adds no useful context
                raise TypeError("You must specify count!") from None

        self._norm = norm_func

        self.id = id
        self.index = 0
        self.video = video
        self.fig, self.ax = plt.subplots()
        self.ax.set_title(title)
        plt.subplots_adjust(bottom=0.25)

        # take first frame; for iterators this consumes frame 0 (index 0)
        frame = next(iter(video))
        frame = self._prepare_image(frame)
        self.img = self.ax.imshow(frame, **kw)

        self.fig.colorbar(self.img, ax=self.ax)

        self.axframe = plt.axes([0.1, 0.1, 0.7, 0.03])
        self.sframe = Slider(self.axframe,
                             '',
                             0,
                             count - 1,
                             valinit=0,
                             valstep=1,
                             valfmt='%i')

        self.axnext = plt.axes([0.7, 0.02, 0.1, 0.05])
        self.bnext = Button(self.axnext, '>')

        self.axnext2 = plt.axes([0.6, 0.02, 0.1, 0.05])
        self.bnext2 = Button(self.axnext2, '>>>')

        self.axprev = plt.axes([0.1, 0.02, 0.1, 0.05])
        self.bprev = Button(self.axprev, '<')

        self.axprev2 = plt.axes([0.2, 0.02, 0.1, 0.05])
        self.bprev2 = Button(self.axprev2, '<<<')

        self.axstop = plt.axes([0.4, 0.02, 0.1, 0.05])
        self.bstop = Button(self.axstop, 'Stop')

        self.axplay = plt.axes([0.3, 0.02, 0.1, 0.05])
        self.bplay = Button(self.axplay, 'Play')

        self.axfast = plt.axes([0.5, 0.02, 0.1, 0.05])
        self.bfast = Button(self.axfast, 'FF')

        self.playing = False
        # frames per step for FF playback and the <<< / >>> buttons; at
        # least one frame, so "fast" is never slower than normal playback
        self.step_fast = self._fast_step(count)
        self.step = 1
        self.pause_duration = 0.001

        def _play():
            # advance by self.step until stopped or past the last frame
            while self.playing:
                plt.pause(self.pause_duration)
                target = self.sframe.val + self.step
                if target >= count:
                    self.playing = False
                else:
                    self.sframe.set_val(target)

        def stop(event):
            self.playing = False

        @skip_runtime_error
        def play(event):
            self.playing = True
            self.step = 1
            _play()

        @skip_runtime_error
        def play_fast(event):
            self.playing = True
            self.step = self.step_fast
            _play()

        def next_frame(event):
            self.sframe.set_val(min(self.sframe.val + 1, count - 1))

        def next_fast(event):
            # always use the fast step, independent of the last play mode
            self.sframe.set_val(min(self.sframe.val + self.step_fast,
                                    count - 1))

        def prev_frame(event):
            self.sframe.set_val(max(self.sframe.val - 1, 0))

        def prev_fast(event):
            self.sframe.set_val(max(self.sframe.val - self.step_fast, 0))

        @skip_runtime_error
        def update(val):
            frame = self._get_frame(int(self.sframe.val))
            if frame is not None:
                frame = self._prepare_image(frame)
                self.img.set_data(frame)
                self.fig.canvas.draw_idle()

        self.sframe.on_changed(update)

        self.bnext.on_clicked(next_frame)
        self.bstop.on_clicked(stop)
        self.bplay.on_clicked(play)
        self.bfast.on_clicked(play_fast)
        self.bnext2.on_clicked(next_fast)
        self.bprev2.on_clicked(prev_fast)
        self.bprev.on_clicked(prev_frame)

    @staticmethod
    def _fast_step(count):
        """Return the fast step: about 1% of ``count`` frames, at least 1."""
        return max(1, count // 100)

    def _get_frame(self, i):
        """Return frame ``i`` of the video, or None if it cannot be reached.

        List-like videos are indexed directly. Iterators are advanced
        forwards until index ``i`` is reached; they cannot be rewound, so a
        request for an earlier frame returns None. :attr:`index` tracks the
        position of the last frame returned.
        """
        try:
            frame = self.video[i]  # assume list-like object
        except TypeError:
            # assume generator: consume forwards up to frame i
            frame = None
            if i > self.index:
                for frame in self.video:
                    self.index += 1
                    if self.index >= i:
                        break
            return frame
        self.index = i
        return frame

    def _prepare_image(self, im):
        """Select camera :attr:`id` from multi-frame data and normalize it."""
        if isinstance(im, (tuple, list)):
            return self._norm(im[self.id])
        return self._norm(im)

    def show(self):
        """Shows video."""
        plt.show()
Exemple #23
0
class Control:
    """Slider plus '-'/'+' buttons for choosing an integer slice index.

    Args:
        figure: Matplotlib figure to add the widget axes to.
        position: ``[left, bottom, width, height]`` of the slider axes in
            figure coordinates. The buttons are placed in a row 0.04 high
            directly below it, each half the slider's width. The sequence
            is copied, so the caller's object is not modified (tuples are
            accepted too).
        label: Slider label.
        initial_value: Starting slider value.
        dim: Number of slices; the slider spans ``0`` to ``dim - 1``.

    The button and slider callbacks are not connected here; presumably the
    caller wires them to :meth:`decrement`, :meth:`increment` and
    :meth:`update`.
    """

    def __init__(self, figure, position, label, initial_value, dim):
        """
            Control group for slice
        """
        # work on a copy: the button layout below rewrites the coordinates,
        # which used to mutate the caller's list in place
        position = list(position)

        # Slider
        self.slider = Slider(figure.add_axes(position,
                                             xticks=[],
                                             yticks=[],
                                             facecolor='#222222'),
                             label,
                             0,
                             dim - 1,
                             valinit=initial_value,
                             valstep=1,
                             valfmt='%1.0f',
                             color='#444444')

        # Button row: just below the slider, 0.04 high, each half as wide
        position[1] = position[1] - 0.04
        position[2] = position[2] / 2
        position[3] = 0.04
        self.buttondown = Button(figure.add_axes(position),
                                 '-',
                                 color='#222222',
                                 hovercolor='#333333')

        # '+' button sits immediately to the right of the '-' button
        position[0] = position[0] + position[2]
        self.buttonup = Button(figure.add_axes(position),
                               '+',
                               color='#222222',
                               hovercolor='#333333')

        # save value for display
        self.value = initial_value

        # save dim for slider limit
        self.lim = dim

    def get_value(self):
        """
            Returns the current value of the control
        """
        return int(self.value)

    def decrement(self, event):
        """Move the slider down by one, never below 0.

        Uses :attr:`value`, which only changes through :meth:`update`.
        """
        new_val = self.value - 1
        if new_val >= 0:
            self.slider.set_val(new_val)

    def increment(self, event):
        """Move the slider up by one, staying below :attr:`lim`.

        Uses :attr:`value`, which only changes through :meth:`update`.
        """
        new_val = self.value + 1
        if new_val < self.lim:
            self.slider.set_val(new_val)

    def update(self, value, callback_func):
        """
            Updates the current value of the control then runs the
            callback_func with the current value
        """
        self.value = int(value)  # set new value
        callback_func(int(value))  # execute callback
Exemple #24
0
class ytViewer(object):
    """Interactive viewer that folds a flat binary signal into a 2-D image.

    The file is read with ``fromfile`` as a flat array of ``dtype`` samples,
    truncated to a whole number of rows of ``fold`` samples and reshaped to
    ``(max_index, fold)``.  A window of ``nmax`` rows starting at row
    ``index`` is shown as an image, and the row under the mouse is plotted as
    a cut underneath.  Sliders set the number of samples trimmed from the
    start ('beg') and end ('end') of every row, a per-row shear, the first
    displayed row ('index') and the window height ('nmax').  While playing,
    an idle callback advances the window by ``increment`` rows per call.

    Keys: space toggles play/pause, 'n' toggles normalisation, 'q' exits.

    Parameters
    ----------
    filename : str
        Raw binary file to read.
    fold : int
        Samples per displayed row.
    nmax : int
        Number of rows displayed at once.
    NORM : bool
        If True the image colour scale spans the data range, otherwise it is
        fixed to 0..255.
    dtype : numpy dtype
        Sample type used when reading the file.
    DEB : int
        Not used by this class.
    shear_val : float
        Initial shear: row ``i`` is rolled by ``int(i * shear_val)`` samples.
    """
    def __init__(self, filename, fold=19277, nmax=100,NORM=True,dtype=int8,DEB=0,shear_val=0.):
        self.UPDATE = True  # playing: the idle callback keeps advancing
        self.color  = True  # indicator colour flag (green when True)

        self.NORM      = NORM
        self.NMAX      = nmax
        self.fold      = fold
        self.increment = 5
        self.index     = 0
        self.shear_val = shear_val

        self.remove_len1 = 0
        self.remove_len2 = 1

        self.fig = figure(figsize=(16,7))

        # Read the samples and drop the incomplete trailing row.
        self.data              = fromfile(filename,dtype=dtype)
        self.max_index         = int(len(self.data)/self.fold)
        self.data              = self.data[:self.max_index*self.fold]
        # folded_data_orig2: rows as read; folded_data_orig: rows after shear;
        # folded_data_orig3: sheared rows with trimmed columns;
        # folded_data: the window of rows currently displayed.
        self.folded_data_orig2 = self.data.reshape(self.max_index,self.fold)
        self.folded_data_orig  = array(self.folded_data_orig2)
        self.folded_data_orig3 = array(self.folded_data_orig)
        self.folded_data       = self.folded_data_orig3[:self.NMAX]

        self.Y0 = 0  # row of folded_data shown in the cut plot

        vmin, vmax = self._clim()
        self.ax = axes([0.1,0.4,0.8,0.47])
        self.im = self.ax.imshow(self.folded_data, interpolation='nearest', aspect='auto', origin='lower', vmin=vmin, vmax=vmax)

        self.cursor = Cursor(self.ax, useblit=True, color='red', linewidth=2)

        self.axh = axes([0.1,0.05,0.8,0.2])
        self.hline, = self.axh.plot(self.folded_data[self.Y0,:])
        self.axh.set_xlim(0,len(self.folded_data[0,:]))
        self.axh.set_ylim(self.folded_data.min(),self.folded_data.max())

        # create 'remove_len1' slider (samples trimmed from the start of each row)
        self.remove_len1_sliderax = axes([0.1,0.925,0.8,0.02])
        self.remove_len1_slider   = Slider(self.remove_len1_sliderax,'beg',0.,self.fold*(3.5/4),self.remove_len1,'%d')
        self.remove_len1_slider.on_changed(self.update_tab)

        # create 'remove_len2' slider (samples trimmed from the end of each row)
        self.remove_len2_sliderax = axes([0.1,0.905,0.8,0.02])
        self.remove_len2_slider   = Slider(self.remove_len2_sliderax,'end',0.,self.fold*(3.5/4),self.remove_len2,'%d')
        self.remove_len2_slider.on_changed(self.update_tab)

        # create 'shear' slider
        self.shear_sliderax = axes([0.1,0.88,0.8,0.02])
        self.shear_slider   = Slider(self.shear_sliderax,'shear',-0.5,0.5,self.shear_val,'%1.3f')
        self.shear_slider.on_changed(self.update_shear)

        # create 'index' slider (first displayed row)
        self.index_sliderax = axes([0.1,0.975,0.8,0.02])
        self.index_slider   = Slider(self.index_sliderax,'index',0,self.max_index-self.increment,0,'%d')
        self.index_slider.on_changed(self.update_param)

        # create 'nmax' slider (number of displayed rows)
        self.nmax_sliderax = axes([0.1,0.955,0.8,0.02])
        self.nmax_slider   = Slider(self.nmax_sliderax,'nmax',0,self.max_index,self.NMAX,'%d')
        self.nmax_slider.on_changed(self.update_tab)

        # Keep the connection ids so the handlers can be disconnected later.
        self.cid_motion = self.fig.canvas.mpl_connect('motion_notify_event', self.mousemove)
        self.cid_key    = self.fig.canvas.mpl_connect('key_press_event', self.keypress)

        # play/pause indicator
        self.axe_toggledisplay  = self.fig.add_axes([0.43,0.27,0.14,0.1])
        self.plot_circle(0,0,2,fc='#00FF7F')
        mpl.pyplot.axis('off')

        if self.shear_val != 0.:
            self.shear()
            # Show the initial shear right away; previously it only became
            # visible after the first slider change.
            self.update_tab(self.shear_val)
        gobject.idle_add(self.update_plot)
        show()

    def _clim(self):
        """Return the image colour limits: the data range when NORM is set,
        otherwise the fixed range 0..255."""
        if not self.NORM:
            return 0, 255
        return self.data.min(), self.data.max()

    def update_shear(self,value):
        """Slider callback: re-shear the data and refresh the display."""
        self.shear_val = round(self.shear_slider.val,3)
        self.shear()
        self.update_tab(value)

    def update_param(self,value):
        """Slider callback: move the window to the row given by the index slider."""
        self.index     = int(round(self.index_slider.val,0))
        self.update_tab(value)

    def shear(self):
        """Rebuild ``folded_data_orig`` from the unsheared rows, rolling row
        ``i`` by ``int(i * shear_val)`` samples (a zero shear therefore
        restores the original rows)."""
        sheared = array(self.folded_data_orig2)
        for i, row in enumerate(self.folded_data_orig2):
            sheared[i,:] = roll(row, int(i*self.shear_val))
        self.folded_data_orig = sheared

    def update_tab(self,val):
        """Recompute the displayed window from the slider values and redraw
        both the image and the cut plot."""
        self.remove_len1 = int(self.remove_len1_slider.val)
        self.remove_len2 = int(self.remove_len2_slider.val)
        self.NMAX        = int(round(self.nmax_slider.val,0))

        # Compute the end column explicitly: the slice ``[a:-0]`` is empty,
        # which is what the old ``-remove_len2`` bound produced whenever the
        # 'end' slider was set to 0.
        end = self.folded_data_orig.shape[1] - self.remove_len2
        self.folded_data_orig3 = array(self.folded_data_orig[:,self.remove_len1:end])
        self.folded_data = self.folded_data_orig3[self.index:(self.index+self.NMAX)]

        self.Y0 = 0
        vmin, vmax = self._clim()
        self.ax.clear()
        self.im     = self.ax.imshow(self.folded_data, interpolation='nearest', aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
        # Reuse the cut axes instead of creating a new one on every update.
        self.axh.clear()
        self.hline, = self.axh.plot(self.folded_data[self.Y0,:])
        self.axh.set_ylim(self.folded_data.min(),self.folded_data.max())
        self.axh.set_xlim(0,len(self.folded_data[self.Y0,:]))
        draw()

    def update_plot(self):
        """Idle callback: show the current window and advance by ``increment`` rows.

        Returns True to stay registered with ``gobject.idle_add`` while
        playing, and False once paused.
        """
        if not self.UPDATE:
            return False
        self.folded_data = self.folded_data_orig3[self.index:(self.index+self.NMAX)]

        ### Update picture ###
        self.im.set_data(self.folded_data)
        self.hline.set_ydata(self.folded_data[self.Y0,:])
        self.index = self.index + self.increment
        # Loop back to the first row instead of running past the end of the
        # data (an empty window raised IndexError in the redraw).
        if self.index > self.max_index - self.increment:
            self.index = 0
        self.index_slider.set_val(self.index)
        draw()
        return True

    def update_cut(self):
        """Redraw the cut plot for row ``Y0``."""
        self.hline.set_ydata(self.folded_data[self.Y0,:])
        draw()

    def keypress(self, event):
        """Handle keys: 'q' exits, 'n' toggles normalisation, space toggles play/pause."""
        if event.key == 'q': # eXit
            sys.exit()
        elif event.key=='n':
            self.NORM = not(self.NORM)
            if not self.NORM:
                self.ax.clear()
                self.im = self.ax.imshow(self.folded_data, interpolation='nearest', aspect='auto',
                origin='lower', vmin=0, vmax=255)
                self.axh.set_ylim(0, 255)
            else:
                self.ax.clear()
                self.im = self.ax.imshow(self.folded_data, interpolation='nearest', aspect='auto',
                origin='lower', vmin=self.folded_data.min(), vmax=self.folded_data.max())
                self.axh.set_ylim(self.folded_data.min(), self.folded_data.max())
        elif event.key == ' ': # play/pause
            self.toggle_update()
        else:
            # single-argument print() behaves the same on Python 2 and 3
            print('Key '+str(event.key)+' not known')

    def mousemove(self, event):
        """Track the mouse over the image and show the row under it as the cut."""
        if event.inaxes!=self.ax: return
        self.X0 = int(round(event.xdata,0))
        # Clamp to a valid row: rounding at the top edge of the image can give
        # one past the last row.
        row = int(round(event.ydata,0))
        self.Y0 = min(max(row, 0), len(self.folded_data)-1)
        self.update_cut()

    def toggle_update(self):
        """Toggle play/pause and redraw the indicator (green while playing,
        orange-red while paused)."""
        self.UPDATE = not(self.UPDATE)
        if self.UPDATE:
            gobject.idle_add(self.update_plot)
        self.color  = not(self.color)
        # Redraw in the existing indicator axes; adding a new axes with the
        # same rectangle on every toggle stacks axes in newer matplotlib.
        self.patch.remove()
        mpl.pyplot.sca(self.axe_toggledisplay)
        self.axe_toggledisplay.clear()
        self.plot_circle(0,0,2,fc='#00FF7F' if self.color else '#FF4500')
        mpl.pyplot.axis('off')
        draw()

    def plot_circle(self,x,y,r,fc='r'):
        """Plot a circle of radius r at position x,y on the current axes."""
        cir = mpl.patches.Circle((x,y), radius=r, fc=fc)
        self.patch = mpl.pyplot.gca().add_patch(cir)
Exemple #25
0
class Notes(object):
    """
    An interactive matplotlib window for ROI drawing, definition of the main
    axis of the eye, and selection of reference frames where the eye is open.

    Parameters
    ----------
    stack : numpy.ndarray
        Frames indexed as ``stack[frame, :, :]``.
    frame_lut : array-like
        Look-up table for frame indices.  It is indexed with a list in
        ``wrap_up``, so presumably a numpy array -- confirm with callers.
    key : mapping
        ``key['m']`` is appended to the window title.
    """

    def __init__(self, stack, frame_lut, key):
        self.wintitle = "ixtract - " + key['m']
        # stack of frames sampled from the files to be analyzed:
        self.stack = stack
        self.nframes = self.stack.shape[0]
        self.frame_lut = frame_lut # look-up table for frame indices
        self.ref_stack = [] # selected frames
        self.ref_frame_inds = [] # indices of selected frames
        self.mode = 'roi' # init in roi drawing mode
        self.points = np.zeros((2,2)) # mouse coords: row 0 press, row 1 release
        self.roi = None # init ROI and axis variables to None-type
        self.axis = None
        self.line = None # patch variables to display ROI and axis on figure
        self.rect = None

        # set up the figure; plt.figure() rather than plt.subplots() so no
        # unused default axes is left underneath the grid axes
        self.fig = plt.figure()
        self._set_window_title(self.wintitle)
        grid = gs.GridSpec(9, 12)
        self.ax_frame = self.fig.add_subplot(grid[:8,:])
        self.ax_frame.axis('off')
        # keep the image handle so frame changes update it instead of
        # stacking a new image on every slider move
        self.frame_image = self.ax_frame.imshow(self.stack[0,:,:], cmap='gray')

        # set up slider axis
        self.ax_slider = self.fig.add_subplot(grid[8,0:11])
        self.slider = Slider(self.ax_slider, 'Frame', 0, self.nframes-1,
                            valinit=0, valfmt='%d')

        # connect callback functions (the axes check is done in the callbacks)
        self.cidpress = self.fig.canvas.mpl_connect(
                        'button_press_event', self.on_click)
        self.cidrelease = self.fig.canvas.mpl_connect(
                        'button_release_event', self.on_release)
        self.slider.on_changed(self.update_frame)
        # disable default matplotlib key bindings; not every backend provides
        # a manager with a registered key handler
        manager, canvas = self.fig.canvas.manager, self.fig.canvas
        handler_id = getattr(manager, 'key_press_handler_id', None)
        if handler_id is not None:
            canvas.mpl_disconnect(handler_id)
        # connect to custom key bindings
        self.cidkey = self.fig.canvas.mpl_connect('key_press_event', self.on_key)

        # display user prompts
        print("Select an ROI and define the main axis of the eye")
        print("    - use 'd' to switch between drawing modes")
        print("    - use 's' to select a few reference frames in which the eye is open")
        print("    - click on the slider or use the arrow keys to scroll between video frames")
        print("    - use 'c' to continue, or 'esc' to quit")

    def _set_window_title(self, title):
        """Set the window title on old and new matplotlib versions.

        ``FigureCanvas.set_window_title`` was deprecated and later removed in
        favour of the figure manager's ``set_window_title``.
        """
        manager = getattr(self.fig.canvas, 'manager', None)
        if manager is not None and hasattr(manager, 'set_window_title'):
            manager.set_window_title(title)
        elif hasattr(self.fig.canvas, 'set_window_title'):
            self.fig.canvas.set_window_title(title)

    def on_click(self, event):
        """stores coordinates of the mouse press"""

        if event.inaxes != self.ax_frame:
            return
        self.points[0,0], self.points[0,1] = event.xdata, event.ydata

    def on_release(self, event):
        """assigns coordinates of click & release to ROI or axis properties
        and draws the appropriate shape patch onto the figure"""

        if event.inaxes != self.ax_frame:
            return
        # assign click data to temporary variables
        self.points[1,0], self.points[1,1] = event.xdata, event.ydata

        # assign click data as object property and draw patch
        if self.mode == 'roi':

            # parse click data: lower-left corner plus width/height
            self.roi = self.points.copy()
            x0 = min(self.points[:,0])
            y0 = min(self.points[:,1])
            dx = np.absolute(self.points[1,0] - self.points[0,0]) # width
            dy = np.absolute(self.points[1,1] - self.points[0,1]) # height

            # remove old patch
            if self.rect is not None:
                self.rect.remove()

            # add new patch
            self.rect = Rectangle((x0,y0), dx, dy, lw=2, ec=[0,1,0.5], fill=False)
            self.ax_frame.add_patch(self.rect)

        elif self.mode == 'axis':

            # parse click data: start point plus signed offset to the release
            self.axis = self.points.copy()
            x0 = self.points[0,0]
            y0 = self.points[0,1]
            dx = self.points[1,0] - self.points[0,0]
            dy = self.points[1,1] - self.points[0,1]

            # remove old patch
            if self.line is not None:
                self.line.remove()

            # add new patch
            # arrow patch used in lieu of a line patch (arrow head width set to zero)
            self.line = Arrow(x0, y0, dx, dy, width=0, lw=2, color=[1,0,1])
            self.ax_frame.add_patch(self.line)

        # update figure
        self.fig.canvas.draw()

    def on_key(self, event):
        """specifies functions of pressed keys in matplotlib figure"""

        # change drawing mode
        if event.key == 'd':
            if self.mode == 'roi':
                self.mode = 'axis'
            elif self.mode == 'axis':
                self.mode = 'roi'
            print("    Drawing: " + self.mode)

        # select current frame as open-eye reference
        elif event.key == 's':
            # builtin int(): np.int was removed in NumPy 1.24
            newframe = int(np.round(self.slider.val))
            print("    Frame %d selected as reference" % newframe)
            self.ref_stack.append(self.stack[newframe,:,:])
            self.ref_frame_inds.append(newframe)

        # change frame
        elif event.key == 'right': # move forward one frame
            if np.round(self.slider.val) < self.nframes-1:
                self.slider.set_val(self.slider.val+1)
        elif event.key == 'left': # move back one frame
            if np.round(self.slider.val) > 0:
                self.slider.set_val(self.slider.val-1)

        # confirm or exit
        elif event.key == 'c': # confirm and continue
            # check that all annotations have been made
            if self.roi is None:
                print("    Please define an ROI")
            elif self.axis is None:
                print("    Please define the eye's axis")
            elif len(self.ref_stack) == 0:
                print("    Please select at least one open-eye reference frame")
            else:
                print("ROI and axis confirmed")
                self.wrap_up()
                plt.close('all')
        elif event.key == 'escape': # quit eye tracking
            print("Eye tracking aborted")
            plt.close('all')
            sys.exit()

    def update_frame(self, val):
        """Updates the displayed frame based on the slider value"""

        newframe = int(np.round(self.slider.val))
        self.frame_image.set_data(self.stack[newframe,:,:])
        # imshow rescaled the colour limits to every new frame; keep doing so
        self.frame_image.autoscale()

    def wrap_up(self):
        """Performs cropping and reference frame computations before closing"""

        # crop stacks based on selected ROI
        print("    Cropping frames...")
        self.stack = img.crop(self.stack, self.roi)
        self.ref_stack = img.crop(np.array(self.ref_stack), self.roi)

        # take median if multiple frames are given as reference
        if len(self.ref_stack.shape) == 2:
            self.ref_frame = self.ref_stack
        elif len(self.ref_stack.shape) == 3:
            self.ref_frame = np.median(self.ref_stack, axis=0)

        # compute mean pixel-wise differences from the reference frame
        print("    Comparing frames to reference...")
        self.diffs = img.mean_diffs(self.stack, self.ref_frame)

        # store indices of frames used as reference
        self.ref_frames = self.frame_lut[self.ref_frame_inds]
Exemple #26
0
def pager(L,shape,vmake,vpaint,offset=0,save=None,savedefaults=dict(dest='pager-sav',format='svg'),bstyle={},**_ka):
  r"""
:param L: a list of arbitrary objects other than :const:`None`
:param shape: a pair (number of rows, number of columns) or a single number if they are equal
:param vmake: a function to initialise the display (see below)
:param vpaint: a function to instantiate the display for a given page (see below)
:param offset: index in *L* used to pick the initial page: page ``1+offset/cellpp`` rounded to the nearest integer, where ``cellpp`` is the number of cells per page
:param save: if :const:`None`, pages are browsed interactively; otherwise a :class:`dict` of :meth:`savefig` keyword arguments overriding *savedefaults*, and all pages are saved, then the figure is closed
:param savedefaults: used as default keyword arguments of method :meth:`savefig` when saving pages; key ``dest`` is the target directory, which must already exist and whose content is deleted before saving
:type savedefaults: :class:`dict`
:param bstyle: keyword arguments for :class:`Menu`, used when the toolbar does not support adding buttons

This function first creates a :class:`Cell` instance with all the remaining arguments, then splits it into a grid of sub-cells according to *shape*, then displays *L* page by page on the grid. Each page displays a slice of *L* of length equal to the product of the components of *shape* (or less, for the final page). The toolbar is enriched with page navigation buttons. A save button also allows saving the whole collection of pages in a given directory (beware: may be long).

Function *vmake* takes as input a :class:`Cell` instance and instantiates it as needed. It can store information (e.g. about the specific role of each artist created in the cell), if needed, by simply setting attributes in the cell. This is called once for each cell at the beginning of the display.

Function *vpaint* takes as input a cell and an element of *L* or None, and displays that element in the cell (or resets the cell to indicate a missing value), possibly using the artists created by *vmake* and stored in the cell. This is called once at each page display and for each cell.

Unfortunately, matplotlib toolbars are not standardised: they depend on the backend and may not support adding buttons.
  """
#------------------------------------------------------------------------------
  from numpy import ceil, rint, clip
  from matplotlib.widgets import Slider
  from matplotlib.pyplot import close
  from pathlib import Path
  from shutil import rmtree
  def gen(L):
    # elements of L, then None forever (pads the last page)
    yield from L
    while True: yield None
  def genc(cell):
    # sub-cells in row-major order
    Nr,Nc = cell.shape
    yield from (cell[row,col] for row in range(Nr) for col in range(Nc))
  def paintp(cell,p,draw=True):
    # paint page p (0-based) into the grid
    for c,x in zip(genc(cell),gen(L[p*cellpp:])): vpaint(c,x)
    if draw: cell.figure.canvas.draw()
  def toggle_ctrl():
    ctrl.ax.set_visible(not ctrl.ax.get_visible())
    cell.figure.canvas.draw()
  def save_all():
    # re-run pager in save mode on a new figure of the same size
    ka = _ka.copy()
    ka.update(fig=None,figsize=(cell.figure.get_figwidth(),cell.figure.get_figheight()))
    pager(L,shape,vmake,vpaint,save={},savedefaults=savedefaults,**ka)
  cell = Cell.create(**_ka)
  Nr,Nc = (shape,shape) if isinstance(shape,int) else shape
  cell.make_grid(Nr,Nc)
  for c in genc(cell): vmake(c)
  cellpp = Nr*Nc
  npage = int(ceil(len(L)/cellpp))
  if save is None:
    actions = [
      ('<<',(lambda:ctrl.set_val(clip(ctrl.val-1,1,npage)))),
      ('>>',(lambda:ctrl.set_val(clip(ctrl.val+1,1,npage)))),
      ('toggle-ctrl',toggle_ctrl),
      ('save-all',save_all),
      ]
    # use the backend toolbar when it can take new actions, otherwise fall
    # back to a Menu drawn in the figure
    menu = getattr(cell.figure.canvas,'toolbar',None)
    if not hasattr(menu,'addAction'): menu = Menu(cell.figure,**bstyle)
    for a,f in actions: menu.addAction(a,f)
    # the slider value is the 1-based page number, rounded in the callback
    ctrl = Slider(cell.figure.add_axes((0.1,0.,.8,.03),visible=False,zorder=1),'page',.5,npage+.5,valinit=0,valfmt='%.0f/{}'.format(npage),closedmin=False,closedmax=False)
    ctrl.on_changed(lambda p:paintp(cell,int(rint(p))-1))
    ctrl.set_val(1+offset/cellpp)
  else:
    s = savedefaults.copy()
    s.update(save)
    pth = Path(s.pop('dest'))
    try:
      # explicit check: an assert would be stripped under python -O
      if not pth.is_dir(): raise NotADirectoryError('not a directory: {}'.format(pth))
      for f in list(pth.iterdir()):
        # previously saved pages are plain files, which rmtree rejects
        if f.is_dir() and not f.is_symlink(): rmtree(str(f))
        else: f.unlink()
    except Exception as e:
      logger.warning('Error on save directory %s: %s',pth,e)
      raise
    try:
      for p in range(npage):
        paintp(cell,p,False)
        cell.figure.savefig(str((pth/'p{:02d}'.format(p)).with_suffix('.'+s['format'])),**s)
    except Exception as e: logger.warning('Error saving page %s: %s',p,e)
    close(cell.figure.number)
class InteractiveView:
    """Matplotlib browser for an image stack with peak overlays.

    Parameters
    ----------
    img : numpy.ndarray
        Image stack; every dimension before the last two is flattened into
        a single frame axis (the caller's array is not modified).
    peaks : pandas.DataFrame or None
        Assumed to be indexed by 0-based frame number (first index level)
        with ``x``, ``y`` and ``w`` columns -- TODO confirm against callers.
        Frames without an entry show no peaks.
    """

    def __init__(self, img, peaks):

        import matplotlib.pyplot as plt
        from matplotlib.widgets import Button, Slider

        self.fig = plt.figure(figsize=(15, 10))
        self.ax = self.fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.15)
        self.i = 1
        self.peaks = peaks

        # Collapse all leading dimensions into one frame axis.  reshape()
        # returns a new array, so the caller's array is left untouched
        # (assigning to ``img.shape`` modified it in place).
        if img.ndim > 2:
            img = img.reshape((-1,) + tuple(img.shape[-2:]))
        self.img = img

        self.text = plt.figtext(0.06, 0.05, '', transform=self.fig.transFigure)

        w = 0.1
        h = 0.050
        y_pos = 0.04

        self.axprev = plt.axes([0.7, y_pos, w, h])
        self.bprev = Button(self.axprev, 'Previous')
        self.bprev.on_clicked(self.prev)

        self.axnext = plt.axes([0.8, y_pos, w, h])
        self.bnext = Button(self.axnext, 'Next')
        self.bnext.on_clicked(self.next)

        # slider value is the 1-based frame number
        self.axslide = plt.axes([0.4, 0.04, 0.25, 0.03])
        self.slider = Slider(self.axslide, 'Frames', 1, int(len(self.img)),
                             valinit=1)
        self.slider.on_changed(self.slide)

        self.artists = []
        self.im = None
        self.draw(0)

    def _frame_peaks(self, frame):
        """Return the peak rows of 0-based ``frame`` as a table, or None.

        ``DataFrame.ix`` (used previously) was removed in pandas 1.0, so the
        old lookup always failed and no peaks were ever drawn.
        """
        try:
            rows = self.peaks.loc[frame]
        except (KeyError, AttributeError, TypeError):
            # no entry for this frame, or no usable peaks table
            return None
        ndim = getattr(rows, 'ndim', 0)
        if ndim == 1:
            # a unique match on a single-level index comes back as a Series
            rows = rows.to_frame().T
        elif ndim != 2:
            return None
        return rows

    def draw(self, i):
        """Display 1-based frame ``i`` with its peaks; values outside
        ``1..len(img)`` fall back to frame 1."""

        from matplotlib.patches import Circle
        import matplotlib.pyplot as plt

        self.i = i

        if i > len(self.img) or i <= 0:
            self.i = 1
            i = 1

        if self.im is not None:
            self.im.remove()

        for art in self.artists:
            art.remove()
        self.artists = []

        n_peaks = 0
        current_peaks = self._frame_peaks(i - 1)
        if current_peaks is not None:
            for _, data in current_peaks.iterrows():
                x = data['x']
                y = data['y']
                w = data['w']
                # presumably x/y are row/column indices, hence (y, x) for
                # the patch position -- confirm
                outline = Circle((y, x), w, alpha=0.4, color='red')
                pt = self.ax.add_patch(outline)
                self.artists.append(pt)

                pt = self.ax.scatter(y, x, marker='+')
                self.artists.append(pt)

            n_peaks = current_peaks.shape[0]

        self.text.set_text("Frame %i/%i | Peaks number %i" % (i,
                                                              len(self.img),
                                                              n_peaks))

        # the ``shape`` keyword of imshow was removed from matplotlib
        self.im = self.ax.imshow(self.img[i-1], interpolation='none', cmap='gray')

        plt.draw()

    def next(self, event=None):
        """Go to the next frame (wraps to frame 1 after the last one)."""
        self.slider.set_val(self.i + 1)

    def prev(self, event=None):
        """Go to the previous frame (frame 0 falls back to frame 1)."""
        self.slider.set_val(self.i - 1)

    def slide(self, event):
        """Slider callback; ``event`` is the new slider value."""
        self.draw(int(event))

    def show(self):
        self.fig.show()
class SelectFromCollection(object):

    """Interactive RLS classifier interface for image segmentation

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The Figure object on which the interface is drawn.

    mmc : rlscore.learner.interactive_rls_classifier.InteractiveRlsClassifier
        Interactive RLS classifier object

    img : numpy.array
        Array consisting of image data; it is recoloured in place by
        ``redrawall``.

    collection : numpy.array, shape = [n_pixels, 2]
        array consisting of the (x,y) coordinates of all usable pixels in the image

    windowsize : int
        Determines the size of a window around grid points (2 * windowsize + 1)
    """

    def __init__(self, fig, mmc, img, collection, windowsize = 0):

        #Initialize the main axis
        ax = fig.add_axes([0.1,0.1,0.8,0.8])
        ax.set_yticklabels([])
        ax.yaxis.set_tick_params(size = 0)
        ax.set_xticklabels([])
        ax.xaxis.set_tick_params(size = 0)
        self.imdata = ax.imshow(img)

        #Initialize LassoSelector on the main axis
        self.lasso = LassoSelector(ax, onselect = self.onselect)
        self.lasso.connect_event('key_press_event', self.onkeypressed)
        self.lasso.line.set_visible(False)

        self.mmc = mmc
        self.img = img
        self.img_orig = img.copy()
        self.collection = collection
        self.selectedset = set()
        self.lockedset = set()
        self.windowsize = windowsize

        #Initialize the fraction slider (fraction of the working set in class 1)
        self.slider_axis = fig.add_axes([0.2, 0.06, 0.6, 0.02])
        self.in_selection_slider = Slider(self.slider_axis,
                                          'Fraction slider',
                                          0.,
                                          1,
                                          valinit = self._class_one_fraction())
        def sliderupdate(val):
            # Round rather than truncate: redrawall sets the slider back to
            # exactly k/n, and int(k/n*n) can give k-1 through float error,
            # which claimed a spurious point and re-entered redrawall.
            val = int(round(val * len(self.mmc.working_set)))
            nonzeroc = len(np.nonzero(self.mmc.classvec_ws)[0])
            if val > nonzeroc:
                claims = val - nonzeroc
                newclazz = 1
            elif val < nonzeroc:
                claims = nonzeroc - val
                newclazz = 0
            else: return
            print('Claimed', claims, 'points for class', newclazz)
            self.claims = claims
            self.mmc.claim_n_points(claims, newclazz)
            self.redrawall()
        self.in_selection_slider.on_changed(sliderupdate)

        #Initialize the display for the RLS objective function
        self.objfun_display_axis = fig.add_axes([0.1, 0.96, 0.8, 0.02])
        self._objfun_image = None
        self._draw_objfun()
        self.objfun_display_axis.set_yticklabels([])
        self.objfun_display_axis.yaxis.set_tick_params(size = 0)

    def _class_one_fraction(self):
        """Fraction of working-set points labelled class 1 (0 for an empty
        working set, which previously raised ZeroDivisionError)."""
        n_ws = len(self.mmc.working_set)
        if n_ws == 0:
            return 0
        return len(np.nonzero(self.mmc.classvec_ws)[0]) / n_ws

    def _draw_objfun(self):
        """Show the steepness vector as a one-row image, replacing the
        previous image instead of stacking a new one on every redraw."""
        if self._objfun_image is not None:
            self._objfun_image.remove()
        self._objfun_image = self.objfun_display_axis.imshow(
            self.mmc.compute_steepness_vector()[np.newaxis, :],
            cmap = plt.get_cmap("Oranges"))
        self.objfun_display_axis.set_aspect('auto')

    def onselect(self, verts):
        """Lasso callback: make the unlocked points inside the lasso the new working set."""
        self.path = Path(verts)
        self.selectedset = set(np.nonzero(self.path.contains_points(self.collection))[0])
        print('Selected ' + str(len(self.selectedset)) + ' points')
        newws = list(self.selectedset - self.lockedset)
        self.mmc.new_working_set(newws)
        self.redrawall()

    def onkeypressed(self, event):
        """Keyboard commands: '1'/'0' label the working set, 'a' selects all
        unlocked points, 'c' runs cyclic descent, 'l'/'u' lock/unlock the
        selection, 'p' prints the AUC of predictions."""
        print('You pressed', event.key)
        if event.key == '1':
            print('Assigned all selected points to class 1')
            newclazz = 1
            self.mmc.claim_all_points_in_working_set(newclazz)
        elif event.key == '0':
            print('Assigned all selected points to class 0')
            newclazz = 0
            self.mmc.claim_all_points_in_working_set(newclazz)
        elif event.key == 'a':
            print('Selected all points')
            newws = list(set(range(len(self.collection))) - self.lockedset)
            self.mmc.new_working_set(newws)
            self.lasso.line.set_visible(False)
        elif event.key == 'c':
            changecount = self.mmc.cyclic_descent_in_working_set()
            print('Performed ', changecount, 'cyclic descent steps')
        elif event.key == 'l':
            print('Locked the class labels of selected points')
            self.lockedset = self.lockedset | self.selectedset
            newws = list(self.selectedset - self.lockedset)
            self.mmc.new_working_set(newws)
        elif event.key == 'u':
            print('Unlocked the selected points')
            self.lockedset = self.lockedset - self.selectedset
            newws = list(self.selectedset - self.lockedset)
            self.mmc.new_working_set(newws)
        elif event.key == 'p':
            print('Compute predictions and AUC on data')
            # NOTE(review): Xmat and auc are module-level names expected to be
            # defined by the calling script -- confirm.
            preds = self.mmc.predict(Xmat)
            print(auc(self.mmc.Y[:, 0], preds[:, 0]))
        self.redrawall()

    def redrawall(self):
        """Recolour the image from the current labels and refresh the slider
        and the objective-function display."""
        #Color all class one labeled pixels red
        oneclazz = np.nonzero(self.mmc.classvec)[0]
        col_row = self.collection[oneclazz]
        rowcs, colcs = col_row[:, 1], col_row[:, 0]
        red = np.array([255, 0, 0])
        for i in range(-self.windowsize, self.windowsize + 1):
            for j in range(-self.windowsize, self.windowsize + 1):
                self.img[rowcs+i, colcs+j, :] = red

        #Return the original color of the class zero labeled pixels
        zeroclazz = np.nonzero(self.mmc.classvec - 1)[0]
        col_row = self.collection[zeroclazz]
        rowcs, colcs = col_row[:, 1], col_row[:, 0]
        for i in range(-self.windowsize, self.windowsize + 1):
            for j in range(-self.windowsize, self.windowsize + 1):
                self.img[rowcs+i, colcs+j, :] = self.img_orig[rowcs+i, colcs+j, :]
        self.imdata.set_data(self.img)

        #Update the slider position according to labeling of the current working set
        self.in_selection_slider.set_val(self._class_one_fraction())

        #Update the RLS objective function display
        self._draw_objfun()

        #Final stuff
        self.lasso.canvas.draw_idle()
        plt.draw()
        print_instructions()
Exemple #29
0
class FindCenters(pb.PointBrowser):
  """
  Semi-automatic fitting of bright points in TEM image
    image    ... 2D image, passed on to PointBrowser
    dragging ... (opt) if True, dragging is allowed for sliders
    kwargs   ... (opt) further keyword arguments for PointBrowser

  Local maxima are searched in a neighbourhood whose size is set by the
  'neighbors' slider; the 'points' slider keeps the maxima with the largest
  local contrast (max - min in the neighbourhood), and 'Refine' fits a 2D
  Gaussian around each point for sub-pixel positions.
  """
  def __init__(self, image, dragging=True, **kwargs):

    # init PointBrowser
    super(FindCenters,self).__init__(image,[[None,None]],**kwargs)
    self.axis.set_title('FitHexagonCenters: %s' % self.imginfo['desc'])
    self.fig.subplots_adjust(bottom=0.2)  # space for sliders

    # add slider for neighborhood size
    self.nbhd_size = 5         # neighborhood size
    axNbhd = self.fig.add_axes([0.2, 0.05, 0.1, 0.04])
    self.sNbhd = Slider(axNbhd,'neighbors ',2,20,valinit=self.nbhd_size,
                        valfmt=' (%d)',dragging=dragging)
    self.sNbhd.on_changed(self.ChangeNeighborhood)

    # add slider for number of points
    self.num_points = 1e6      # number of local maxima to find
    axNum  = self.fig.add_axes([0.45, 0.05, 0.3, 0.04])
    self.sNum = Slider(axNum, 'points ',0,100,valinit=self.num_points,
                       valfmt=' (%d)',dragging=dragging)
    self.sNum.on_changed(self.ChangeMaxPoints)

    # add buttons
    axRefine = self.fig.add_axes([0.85, 0.7, 0.1, 0.04])
    self.bRefine = Button(axRefine,'Refine')
    self.bRefine.on_clicked(self.RefineCenters)

    # initial calculation of local maximas
    self.ChangeNeighborhood(self.nbhd_size)


  def ChangeNeighborhood(self,val):
    " slider callback: recompute local maxima for neighbourhood size val "
    self.nbhd_size = int(val)
    # run initial peak fit
    maxima,diff = self.find_local_maxima(self.image,self.nbhd_size)
    # update max-number of points in points slider
    self.sNum.valmax = Nmax = np.sum(maxima)           # number of local max
    self.sNum.ax.set_xlim((self.sNum.valmin, Nmax))    # rescale slider
    self.sNum.set_val(min(self.num_points,Nmax))       # update value (calls ChangeMaxPoints)


  def ChangeMaxPoints(self,val):
    " slider callback: keep the val strongest local maxima "
    self.num_points = int(val)
    self.points = self.refine_local_maxima(self.num_points)
    self._update_points()


  def RefineCenters(self,event):
    " refine positions by fitting 2D Gaussian in neighborhood of local max "
    from  scipy.optimize import leastsq
    from  sys import stdout

    NN    = self.nbhd_size
    Nx,Ny = self.image.shape
    dx,dy = np.mgrid[-NN:NN+1,-NN:NN+1]
    npoints = len(self.points)

    def gauss(x0,y0,A,B,fwhm):
      return A*np.exp( - ((dx-x0)**2+(dy-y0)**2) / fwhm**2) + B

    # refine each point separately
    self.points = self.points.astype(float) # allow subpixel precision
    for ip in range(npoints):
      # nearest pixel as integers: slicing/indexing with floats raises in
      # current numpy (and points are floats after a previous refinement)
      x,y = (int(v) for v in np.round(self.points[ip]))

      # get neighborhood (skip border)
      xmin,xmax = x-NN, x+NN
      ymin,ymax = y-NN, y+NN
      if xmin<0 or ymin<0 or xmax>=Nx or ymax>=Ny: continue
      nbhd = self.image[xmin:xmax+1,ymin:ymax+1]
      assert nbhd.shape == (2*NN+1,2*NN+1)

      # least-squares fit of the Gaussian; offset (p[0],p[1]) corrects the position
      p0     = (0.,0.,self.image[x,y],0.,NN/2.)                  # initial guess
      p,ierr = leastsq(lambda q: (nbhd - gauss(*q)).ravel(), p0)
      self.points[ip] = (x+p[0],y+p[1])

      # DEBUG: report progress / plot fits for each point
      if self.verbosity > 0:
        stdout.write("Refining Points...  %d %%\r" % (100*ip//max(npoints-1,1)))
      if self.verbosity > 3:
        stdout.write("IN:   %s\n" % (p0,))
        stdout.write("OUT:  %s\n" % (p,))
      if self.verbosity > 10:
        plt.figure()
        ix = nbhd.shape[0]//2
        plt.plot(dy[ix],nbhd[ix],  'k',label='image')
        plt.plot(dy[ix],gauss(*p0)[ix],'g',label='first guess')
        plt.plot(dy[ix],gauss(*p)[ix], 'r',label='final fit')
        plt.plot(dx[:,ix],nbhd[:,ix],      'k--')
        plt.plot(dx[:,ix],gauss(*p0)[:,ix],'g--')
        plt.plot(dx[:,ix],gauss(*p)[:,ix], 'r--')
        plt.legend()
        plt.show()
    if self.verbosity > 0:  stdout.write("Refining Points. Finished.\n")
    stdout.flush()

    self._update_points()


  def find_local_maxima(self, data, neighborhood_size):
    """
     find local maxima within neighborhood
      idea from http://stackoverflow.com/questions/9111711
      (get-coordinates-of-local-maxima-in-2d-array-above-certain-value)

     Returns a boolean mask with one pixel per maximum and the difference
     between the local max and min filters; both are also stored on self
     for refine_local_maxima.
    """

    # find local maxima in image (width specified by neighborhood_size)
    data_max = filters.maximum_filter(data,neighborhood_size)
    maxima   = (data == data_max)
    assert np.sum(maxima) > 0        # we should always find local maxima

    # remove connected pixels (plateaus): keep the centre of each bounding
    # box.  A fresh array is used because ``maxima *= 0`` on a boolean
    # array is rejected by numpy's same_kind casting rule.
    labeled, num_objects = ndimage.label(maxima)
    slices = ndimage.find_objects(labeled)
    maxima = np.zeros_like(maxima)
    for sx,sy in slices:
      maxima[(sx.start+sx.stop-1)//2, (sy.start+sy.stop-1)//2] = True

    # calculate difference between local maxima and lowest
    # pixel in neighborhood (used in refine_local_maxima)
    data_min = filters.minimum_filter(data,neighborhood_size)
    diff     = data_max - data_min
    self._maxima = maxima
    self._diff   = diff

    return maxima,diff

  def refine_local_maxima(self,N):
    " select highest N local maxima using thresholding; returns (n,2) positions "

    maxima = self._maxima;  diff = self._diff

    # zero requested points: ``[-0]`` would otherwise select the smallest
    # difference and keep almost everything
    if N <= 0:
      return np.empty((0, maxima.ndim), dtype=np.intp)

    # select highest local maxima using thresholding
    if np.sum(maxima) > N:
      # threshold = N-th largest difference among the local maxima
      thresh = np.sort(diff[maxima].ravel())[-N]
      # '>=' keeps the N-th point itself ('>' returned only N-1 points)
      maxima = np.logical_and(maxima, diff>=thresh)

    # TODO: refine fit by local 2D Gauss-Fit

    # return list of x,y positions of local maxima
    return np.asarray(np.where(maxima)).T
Exemple #30
0
class BasicDendrogramViewer(object):
    """
    Interactive Matplotlib viewer for a dendrogram and the array it was built from.

    The left panel (``ax1``) shows the data -- one slice at a time for 3-d
    data, chosen with a slider -- and outlines the selected structure with a
    red contour. The right panel (``ax2``) shows the dendrogram lines on a
    log y-axis, with the selected subtree highlighted in red. Structures are
    selected by clicking a pixel in the image or a branch of the dendrogram.
    Two further sliders set the image display limits (vmin / vmax).

    Parameters
    ----------
    dendrogram : object
        Must provide a ``data`` array with 2 or 3 dimensions and a
        ``node_at(indices)`` method, and be accepted by ``DendrogramPlotter``.

    Raises
    ------
    ValueError
        If ``dendrogram.data`` is not 2- or 3-dimensional.
    """

    def __init__(self, dendrogram):

        if dendrogram.data.ndim not in [2, 3]:
            raise ValueError("Only 2- and 3-dimensional arrays are supported at this time")

        self.array = dendrogram.data
        self.dendrogram = dendrogram
        self.plotter = DendrogramPlotter(dendrogram)
        self.plotter.sort(reverse=True)

        # Get the lines as individual elements, and the mapping from line to structure
        self.lines = self.plotter.get_lines()

        # Define the currently selected subtree
        self.selected = None
        self.selected_lines = None
        self.selected_contour = None

        # Initiate plot
        import matplotlib.pyplot as plt
        from matplotlib.widgets import Slider

        self.fig = plt.figure(figsize=(14, 8))

        self.ax1 = self.fig.add_axes([0.1, 0.1, 0.4, 0.7])

        # Colour limits from the finite values only (NaN and +/-inf ignored)
        finite = self.array[np.isfinite(self.array)]
        self._clim = (np.min(finite), np.max(finite))

        imshow_kwargs = dict(origin='lower', interpolation='nearest',
                             vmin=self._clim[0], vmax=self._clim[1],
                             cmap=plt.cm.gray)

        if self.array.ndim == 2:

            self.slice = None
            self.image = self.ax1.imshow(self.array, **imshow_kwargs)

        else:

            # middle slice; integer division is always a valid index
            self.slice = self.array.shape[0] // 2
            self.image = self.ax1.imshow(self.array[self.slice, :, :], **imshow_kwargs)

            self.slice_slider_ax = self.fig.add_axes([0.1, 0.95, 0.4, 0.03])
            self.slice_slider_ax.set_xticklabels("")
            self.slice_slider_ax.set_yticklabels("")
            # last valid slice is shape[0] - 1
            self.slice_slider = Slider(self.slice_slider_ax, "3-d slice", 0, self.array.shape[0] - 1,
                                       valinit=self.slice, valfmt="%i")
            self.slice_slider.on_changed(self.update_slice)
            self.slice_slider.drawon = False

        self.vmin_slider_ax = self.fig.add_axes([0.1, 0.90, 0.4, 0.03])
        self.vmin_slider_ax.set_xticklabels("")
        self.vmin_slider_ax.set_yticklabels("")
        self.vmin_slider = Slider(self.vmin_slider_ax, "vmin", self._clim[0], self._clim[1], valinit=self._clim[0])
        self.vmin_slider.on_changed(self.update_vmin)
        self.vmin_slider.drawon = False

        self.vmax_slider_ax = self.fig.add_axes([0.1, 0.85, 0.4, 0.03])
        self.vmax_slider_ax.set_xticklabels("")
        self.vmax_slider_ax.set_yticklabels("")
        self.vmax_slider = Slider(self.vmax_slider_ax, "vmax", self._clim[0], self._clim[1], valinit=self._clim[1])
        self.vmax_slider.on_changed(self.update_vmax)
        self.vmax_slider.drawon = False

        self.ax2 = self.fig.add_axes([0.6, 0.3, 0.35, 0.4])
        self.ax2.add_collection(self.lines)

        self.selected_label = self.fig.text(0.6, 0.75, "No structure selected", fontsize=18)

        # Paths generally have different numbers of vertices, so concatenate
        # rather than letting numpy build a (ragged) 2-d array
        paths = self.lines.get_paths()
        x = np.concatenate([p.vertices[:, 0] for p in paths])
        y = np.concatenate([p.vertices[:, 1] for p in paths])
        xmin = np.min(x)
        xmax = np.max(x)
        ymin = np.min(y)
        ymax = np.max(y)
        self.lines.set_picker(2.)
        dx = xmax - xmin
        self.ax2.set_xlim(xmin - dx * 0.1, xmax + dx * 0.1)
        self.ax2.set_ylim(ymin * 0.5, ymax * 2.0)
        self.ax2.set_yscale('log')

        self.fig.canvas.mpl_connect('pick_event', self.line_picker)
        self.fig.canvas.mpl_connect('button_press_event', self.select_from_map)

        plt.show()

    @staticmethod
    def _toolbar_busy(event):
        """Return True if a toolbar tool (pan/zoom) is active for *event*'s canvas."""
        toolbar = event.canvas.toolbar
        # figures created without a toolbar have canvas.toolbar = None
        return toolbar is not None and toolbar.mode != ''

    def update_slice(self, pos=None):
        """Show slice *pos* (3-d data) and redraw the selected contour."""
        if self.array.ndim == 2:
            self.image.set_array(self.array)
        else:
            # clamp so the index is always a valid slice
            self.slice = int(min(max(round(pos), 0), self.array.shape[0] - 1))
            self.image.set_array(self.array[self.slice, :, :])

        self.remove_contour()
        self.update_contour()

        self.fig.canvas.draw()

    def update_vmin(self, vmin):
        """Set the lower display limit, never letting it exceed the upper one."""
        if vmin > self._clim[1]:
            self._clim = (self._clim[1], self._clim[1])
        else:
            self._clim = (vmin, self._clim[1])
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def update_vmax(self, vmax):
        """Set the upper display limit, never letting it drop below the lower one."""
        if vmax < self._clim[0]:
            self._clim = (self._clim[0], self._clim[0])
        else:
            self._clim = (self._clim[0], vmax)
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def select_from_map(self, event):
        """Select the structure under a mouse click in the image panel."""

        # Only do this if no tools are currently selected
        if self._toolbar_busy(event):
            return

        if event.inaxes is not self.ax1:
            return

        # Find pixel co-ordinates of click
        ix = int(round(event.xdata))
        iy = int(round(event.ydata))

        # Ignore clicks outside the data (e.g. after zooming out)
        ny, nx = self.array.shape[-2:]
        if not (0 <= ix < nx and 0 <= iy < ny):
            return

        if self.array.ndim == 2:
            indices = (iy, ix)
        else:
            indices = (self.slice, iy, ix)

        # Select the structure
        structure = self.dendrogram.node_at(indices)
        self.select(structure)

        # Re-draw
        event.canvas.draw()

    def line_picker(self, event):
        """Select the structure whose dendrogram branch was picked."""

        # Only do this if no tools are currently selected
        if self._toolbar_busy(event):
            return

        # event.ind gives the indices of the paths that have been selected

        # Find levels of selected paths
        peaks = [event.artist.structures[i].get_peak(subtree=True)[1] for i in event.ind]

        # Pick the path with the highest peak (ties: first one, as np.argmax does)
        ind = event.ind[np.argmax(peaks)]

        # Extract structure
        structure = event.artist.structures[ind]

        # If 3-d, select the slice containing the peak
        if self.array.ndim == 3:
            peak_index = structure.get_peak(subtree=True)
            self.slice_slider.set_val(peak_index[0][0])

        # Select the structure
        self.select(structure)

        # Re-draw
        event.canvas.draw()

    def select(self, structure):
        """Highlight *structure* in both panels; None clears the selection."""

        # Remove previously selected collection
        if self.selected_lines is not None:
            self.selected_lines.remove()
            self.selected_lines = None

        self.remove_contour()

        if structure is None:
            # forget the old selection so update_contour() does not redraw it
            self.selected = None
            self.selected_label.set_text("No structure selected")
            self.fig.canvas.draw()
            return

        self.selected = structure

        self.selected_label.set_text("Selected structure: {0}".format(structure.idx))

        # Get collection for this substructure
        self.selected_lines = self.plotter.get_lines(structure=structure)
        self.selected_lines.set_color('red')
        self.selected_lines.set_linewidth(2)
        self.selected_lines.set_alpha(0.5)

        # Add to axes
        self.ax2.add_collection(self.selected_lines)

        self.update_contour()

    def remove_contour(self):
        """Remove the contour of the current selection from the image panel."""

        if self.selected_contour is None:
            return

        contour = self.selected_contour
        if hasattr(contour, 'remove'):
            # recent Matplotlib: the ContourSet is itself a removable artist
            contour.remove()
        else:
            for collection in contour.collections:
                collection.remove()
        self.selected_contour = None

    def update_contour(self):
        """Draw the outline of the selected structure on the displayed slice."""

        if self.selected is not None:
            mask = self.selected.get_mask(self.array.shape, subtree=True)
            if self.array.ndim == 3:
                mask = mask[self.slice, :, :]
            self.selected_contour = self.ax1.contour(mask, colors='red', linewidths=2, levels=[0.5], alpha=0.5)
Exemple #31
0
class viscm_editor(object):
    """
    Interactive editor for a colormap defined by a Bezier curve.

    The window shows the Bezier control points (click-drag to move,
    shift-click to add, control-click to delete), the resulting colormap,
    two sliders for the minimum and maximum lightness J', and buttons to
    show the 3D gamut, save the colormap and open the properties viewer.

    Parameters
    ----------
    min_Jp, max_Jp : float
        Initial values of the J'_min / J'_max sliders (slider range 0-100).
    xp, yp : list of float, optional
        Bezier control point coordinates; a built-in default set is used
        when omitted.
    """

    # Where save_colormap() writes the generated module.
    save_path = '/tmp/new_cm.py'

    def __init__(self, min_Jp=15, max_Jp=95, xp=None, yp=None):
        from .bezierbuilder import BezierModel, BezierBuilder

        axes = _viscm_editor_axes()

        ax_btn_wireframe = plt.axes([0.7, 0.15, 0.1, 0.025])
        self.btn_wireframe = Button(ax_btn_wireframe, 'Show 3D gamut')
        self.btn_wireframe.on_clicked(self.plot_3d_gamut)

        ax_btn_save = plt.axes([0.81, 0.15, 0.1, 0.025])
        self.btn_save = Button(ax_btn_save, 'Save colormap')
        self.btn_save.on_clicked(self.save_colormap)

        ax_btn_props = plt.axes([0.81, 0.1, 0.1, 0.025])
        self.btn_props = Button(ax_btn_props, 'Properties')
        self.btn_props.on_clicked(self.show_viscm)
        self.prop_windows = []

        # Transparent axes background for the lightness sliders.
        # (`facecolor` replaces the `axisbg` keyword removed in Matplotlib 2.2.)
        axcolor = 'None'
        ax_jp_min = plt.axes([0.1, 0.1, 0.5, 0.03], facecolor=axcolor)
        ax_jp_min.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap='gray')
        ax_jp_min.set_xlim(0, 100)

        ax_jp_max = plt.axes([0.1, 0.15, 0.5, 0.03], facecolor=axcolor)
        ax_jp_max.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap='gray')
        ax_jp_max.set_xlim(0, 100)

        self.jp_min_slider = Slider(ax_jp_min, r"$J'_\mathrm{min}$", 0, 100, valinit=min_Jp)
        self.jp_max_slider = Slider(ax_jp_max, r"$J'_\mathrm{max}$", 0, 100, valinit=max_Jp)

        self.jp_min_slider.on_changed(self._jp_update)
        self.jp_max_slider.on_changed(self._jp_update)

        # This is my favorite set of control points so far (just from playing
        # around with things):
        #   min_Jp = 15
        #   max_Jp = 95
        #   xp =
        #     [-4, 27.041103603603631, 84.311067635550557, 12.567076579094476, -9.6]
        #   yp =
        #     [-34, -41.447876447876524, 36.28563443264386, 25.357741755170423, 41]
        # -- njs, 2015-04-05

        if xp is None:
            xp = [-4, 38.289146128951984, 52.1923711457504,
                  39.050944362271053, 18.60872492130315, -9.6]

        if yp is None:
            yp = [-34, -34.34528254916614, -21.594701710471412,
                  31.701084689194829, 29.510846891948262, 41]

        self.bezier_model = BezierModel(xp, yp)
        self.cmap_model = BezierCMapModel(self.bezier_model,
                                          self.jp_min_slider.val,
                                          self.jp_max_slider.val)
        self.highlight_point_model = HighlightPointModel(self.cmap_model, 0.5)

        self.bezier_builder = BezierBuilder(axes['bezier'], self.bezier_model)
        self.bezier_gamut_viewer = GamutViewer2D(axes['bezier'],
                                                 self.highlight_point_model)
        self.bezier_highlight_point_view = HighlightPoint2DView(
            axes['bezier'],
            self.highlight_point_model)

        #draw_pure_hue_angles(axes['bezier'])
        axes['bezier'].set_xlim(-100, 100)
        axes['bezier'].set_ylim(-100, 100)

        self.cmap_view = CMapView(axes['cm'], self.cmap_model)
        self.cmap_highlighter = HighlightPointBuilder(
            axes['cm'],
            self.highlight_point_model)

        print("Click sliders at bottom to change min/max lightness")
        print("Click on colorbar to adjust gamut view")
        print("Click-drag to move control points, ")
        print("  shift-click to add, control-click to delete")

    def plot_3d_gamut(self, event):
        """Button callback: open a 3D wireframe view of the colormap's gamut."""
        fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
        self.wireframe_view = WireframeView(ax,
                                            self.cmap_model,
                                            self.highlight_point_model)
        plt.show()

    def save_colormap(self, event):
        """Button callback: write the current colormap as a Python module to self.save_path."""
        import textwrap

        template = textwrap.dedent('''
        from matplotlib.colors import LinearSegmentedColormap
        from numpy import nan, inf

        # Used to reconstruct the colormap in pycam02ucs.cm.viscm
        parameters = {{'xp': {xp},
                      'yp': {yp},
                      'min_Jp': {min_Jp},
                      'max_Jp': {max_Jp}}}

        cm_data = {array_list}

        test_cm = LinearSegmentedColormap.from_list(__file__, cm_data)


        if __name__ == "__main__":
            import matplotlib.pyplot as plt
            import numpy as np

            try:
                from pycam02ucs.cm.viscm import viscm
                viscm(test_cm)
            except ImportError:
                print("pycam02ucs not found, falling back on simple display")
                plt.imshow(np.linspace(0, 100, 256)[None, :], aspect='auto',
                           cmap=test_cm)
            plt.show()
        ''')

        rgb, _ = self.cmap_model.get_sRGB(num=256)
        # strip the leading 'array(' and trailing ')' to leave a list literal
        array_list = np.array_repr(rgb, max_line_width=78)
        array_list = array_list.replace('array(', '')[:-1]

        xp, yp = self.cmap_model.bezier_model.get_control_points()

        data = dict(array_list=array_list,
                    xp=xp,
                    yp=yp,
                    min_Jp=self.cmap_model.min_Jp,
                    max_Jp=self.cmap_model.max_Jp)

        with open(self.save_path, 'w') as f:
            f.write(template.format(**data))

        print("*" * 50)
        print("Saved colormap to {}".format(self.save_path))
        print("*" * 50)

    def show_viscm(self, event):
        """Button callback: open the colormap properties viewer."""
        cm = LinearSegmentedColormap.from_list(
            'test_cm',
            self.cmap_model.get_sRGB(num=256)[0])
        self.prop_windows.append(viscm(cm, name='test_cm'))
        plt.show()

    def _jp_update(self, val):
        """Slider callback: keep J'_min <= J'_max and push the limits to the model."""
        jp_min = self.jp_min_slider.val
        jp_max = self.jp_max_slider.val

        smallest, largest = min(jp_min, jp_max), max(jp_min, jp_max)
        if (jp_min > smallest) or (jp_max < largest):
            # sliders crossed: swap their values (set_val re-enters this callback)
            self.jp_min_slider.set_val(smallest)
            self.jp_max_slider.set_val(largest)

        self.cmap_model.set_Jp_minmax(smallest, largest)
Exemple #32
0
class GUI(animation.TimedAnimation):
    """
        interface for viewing a movie and its associated roi, traces, other things

        implementation is only through matplotlib, and is non-blocking

        backend affects performance somewhat dramatically. have achieved decent performance with qt5agg and tkagg
    """
    def __init__(self, mov, roi, traces, images=None, cmap=pl.cm.viridis, **kwargs):
        """
            Parameters:
                mov : 3d np array, 0'th axis is time/frames
                roi : 3d np array, one roi per item in 0'th axis, each of which is a True/False mask indicating roi (True=inside roi)
                traces : 2d np array, 0'th axis is time, 1st axis is sources
                images: dictionary of still images (label -> image); optional
                kwargs: forwarded to animation.TimedAnimation

            Attributes:
                roi_kept : boolean array of length of supplied roi, indicating whether or not roi should be kept based on user input

        """
        if images is None:
            images = {}

        self.mov = mov
        # flat pixel indices of each roi; a list because rois differ in size
        self.roi_idxs = [np.flatnonzero(r) for r in roi]
        self.roi_centers = np.array([np.mean(np.argwhere(r), axis=0) for r in roi])
        self.roi_orig = roi.copy()
        self.roi = pretty_roi(roi)
        self.roi_kept = np.ones(len(self.roi_idxs), dtype=bool)
        self.traces = traces
        self.images = images

        # figure setup
        self.cmap = cmap
        self.fig = pl.figure()
        NR, NC = 128, 32
        gs = gridspec.GridSpec(nrows=NR, ncols=NC)
        gs.update(wspace=0.1, hspace=0.1, left=.04, right=.96, top=.98, bottom=.02)
        # movie axes
        self.ax_contrast0 = self.fig.add_subplot(gs[0:5, 0:NC//3])
        self.ax_contrast1 = self.fig.add_subplot(gs[5:10, 0:NC//3])
        self.ax_mov = self.fig.add_subplot(gs[10:55, 0:NC//3])
        self.ax_mov.axis('off')
        self.ax_img = self.fig.add_subplot(gs[55:100, 0:NC//3])
        self.ax_img.axis('off')
        self.axs_imbuts = [self.fig.add_subplot(gs[110:128, idx*2:idx*2+2]) for idx in range(len(self.images))]
        # trace axes
        self.ax_trcs = self.fig.add_subplot(gs[0:64, NC//2:])
        self.ax_trc = self.fig.add_subplot(gs[65:85, NC//2:])
        self.ax_trc.set_xlim([0, len(self.traces)])
        self.ax_nav = self.fig.add_subplot(gs[85:90, NC//2:])
        self.ax_nav.set_xlim([0, len(self.traces)])
        self.ax_nav.axis('off')
        self.ax_rm = self.fig.add_subplot(gs[95:110, NC//2:])

        # interactivity
        self.c0, self.c1 = 0, 100
        self.sl_contrast0 = Slider(self.ax_contrast0, 'Low', 0., 255.0, valinit=self.c0, valfmt='%d')
        self.sl_contrast1 = Slider(self.ax_contrast1, 'Hi', 0., 255.0, valinit=self.c1, valfmt='%d')
        self.sl_contrast0.on_changed(self.evt_contrast)
        self.sl_contrast1.on_changed(self.evt_contrast)
        self.img_buttons = [Button(ax, k) for k, ax in zip(list(self.images.keys()), self.axs_imbuts)]
        self.but_rm = Button(self.ax_rm, 'Remove All ROIs Currently in FOV')

        # display initial things
        self.movdata = self.ax_mov.imshow(self.mov[0])
        self.movdata.set_animated(True)
        self.roidata = self.ax_mov.imshow(self.roi, cmap=self.cmap, alpha=0.5, vmin=np.nanmin(self.roi), vmax=np.nanmax(self.roi))
        self.trdata, = self.ax_trc.plot(np.zeros(len(self.traces)))
        self.navdata, = self.ax_nav.plot([-2, -2], [-1, np.max(self.traces)], 'r-')
        # still-image display only exists when images were supplied
        self.imgdata = None
        if self.images:
            lab, im = list(self.images.items())[0]
            self.imgdata = self.ax_img.imshow(im)
            # title, as in evt_imbut (axis('off') hides axis labels)
            self.ax_img.set_title(lab)
        self.plot_current_traces()

        # callbacks
        for ib, lab in zip(self.img_buttons, list(self.images.keys())):
            ib.on_clicked(lambda evt, lab=lab: self.evt_imbut(evt, lab))
        self.but_rm.on_clicked(self.remove_roi)
        self.fig.canvas.mpl_connect('button_press_event', self.evt_click)
        self.ax_mov.callbacks.connect('xlim_changed', self.evt_zoom)
        self.ax_mov.callbacks.connect('ylim_changed', self.evt_zoom)

        # runtime
        self._idx = -1
        self.t0 = self._timestamp()
        self.always_draw = [self.movdata, self.roidata, self.navdata]
        self.blit_clear_axes = [self.ax_mov, self.ax_nav]

        # parent init
        animation.TimedAnimation.__init__(self, self.fig, interval=40, blit=True, **kwargs)

    @staticmethod
    def _timestamp():
        """Return a high-resolution timestamp (time.clock was removed in Python 3.8)."""
        clock = getattr(time, 'perf_counter', None)
        if clock is None:  # Python 2 fallback
            clock = time.clock
        return clock()

    @property
    def frame_seq(self):
        # advance (and wrap) the frame index, move the nav marker, yield the frame
        self._idx += 1
        if self._idx >= len(self.mov):
            self._idx = 0
        self.navdata.set_xdata([self._idx, self._idx])
        yield self.mov[self._idx]

    @frame_seq.setter
    def frame_seq(self, val):
        # the base class assigns frame_seq; frames are produced by the getter instead
        pass

    def new_frame_seq(self):
        return self.mov

    def _init_draw(self):
        self._draw_frame(self.mov[0])
        self._drawn_artists = self.always_draw

    def _draw_frame(self, d):
        self.t0 = self._timestamp()

        self.movdata.set_data(d)

        # blit
        self._drawn_artists = self.always_draw
        for da in self._drawn_artists:
            da.set_animated(True)

    def _blit_clear(self, artists, bg_cache):
        for ax in self.blit_clear_axes:
            if ax in bg_cache:
                self.fig.canvas.restore_region(bg_cache[ax])

    def evt_contrast(self, val):
        """Slider callback: apply the Low/Hi contrast limits, keeping Low below Hi."""
        self.c0 = self.sl_contrast0.val
        self.c1 = self.sl_contrast1.val

        if self.c0 > self.c1:
            self.c0 = self.c1 - 1
            self.sl_contrast0.set_val(self.c0)
        if self.c1 < self.c0:
            self.c1 = self.c0 + 1
            self.sl_contrast1.set_val(self.c1)

        self.movdata.set_clim(vmin=self.c0, vmax=self.c1)
        if self.imgdata is not None:
            self.imgdata.set_clim(vmin=self.c0, vmax=self.c1)

    def evt_imbut(self, evt, lab):
        """Button callback: show the still image stored under *lab*."""
        self.imgdata.set_data(self.images[lab])
        self.ax_img.set_title(lab)

    def evt_click(self, evt):
        """Click on the movie selects a (kept) roi; click on the nav bar seeks."""
        if not evt.inaxes:
            return

        elif evt.inaxes == self.ax_mov:
            # select roi
            x, y = int(np.round(evt.xdata)), int(np.round(evt.ydata))
            if not (0 <= y < self.roi.shape[0] and 0 <= x < self.roi.shape[1]):
                return
            idx = np.ravel_multi_index((y, x), self.roi.shape)
            # only rois still displayed (kept) can be selected
            inside = [i for i, (ri, kept) in enumerate(zip(self.roi_idxs, self.roi_kept))
                      if kept and idx in ri]
            if not inside:
                return
            self.set_current_trace(inside[0])

        elif evt.inaxes in [self.ax_nav]:
            x = int(np.round(evt.xdata))
            x = min(max(x, 0), len(self.mov) - 1)
            # frame_seq increments before yielding, so frame x is shown next
            self._idx = x - 1

    def evt_zoom(self, *args):
        self.plot_current_traces()

    def set_current_trace(self, idx):
        """Plot the trace of roi *idx* in the single-trace axes."""
        idx = int(np.squeeze(idx))
        # same colour mapping as plot_current_traces: one colour per supplied roi
        col = self.cmap(np.linspace(0, 1, len(self.roi_idxs)))[idx]
        t = self.traces[:, idx]
        self.trdata.set_ydata(t)
        self.trdata.set_color(col)
        self.ax_trc.set_ylim([t.min(), t.max()])
        self.ax_trc.set_title('ROI {}'.format(idx))
        self.ax_trc.figure.canvas.draw()

    def get_current_roi(self):
        """Boolean mask of kept rois whose centre lies in the current movie view."""
        croi = np.array([isin(rc, self.ax_mov) for rc in self.roi_centers], dtype=bool)
        croi[~self.roi_kept] = False
        return croi

    def remove_roi(self, evt):
        """Button callback: discard every kept roi currently in the field of view."""
        self.current_roi = self.get_current_roi()
        self.roi_kept[self.current_roi] = False
        # update
        if np.any(self.roi_kept):
            proi = pretty_roi(self.roi_orig[self.roi_kept])
            self.roidata.set_data(proi)
            self.roidata.set_clim(vmin=np.nanmin(proi), vmax=np.nanmax(proi))
        else:
            # hide instead of remove(): roidata stays in always_draw for blitting,
            # and a second call must not try to remove it again
            self.roidata.set_visible(False)

    def plot_current_traces(self):
        """Stack the traces of all rois in view, each shifted above the previous one."""
        self.current_roi = self.get_current_roi()

        if not np.any(self.current_roi):
            return
        for line in self.ax_trcs.get_lines():
            line.remove()

        cols = self.cmap(np.linspace(0, 1, len(self.roi_idxs)))[self.current_roi]
        offset = 0
        for t, c in zip(self.traces.T[self.current_roi], cols):
            shifted = (t - t.min()) + offset
            self.ax_trcs.plot(shifted, color=c)
            # the next trace starts at the top of this one
            offset = shifted.max()
        self.ax_trcs.set_ylim([0, offset])
def view_patches_bar(Yr, A, C, b, f, d1, d2, YrA=None, img=None):
    """view spatial and temporal components interactively

     A slider (or the left/right arrow keys) steps through the nr
     components first and then through the nb background components.

     Parameters:
     -----------
     Yr:    np.ndarray
            movie in format pixels (d) x frames (T)

     A:     sparse matrix
                matrix of spatial components (d x K)

     C:     np.ndarray
                matrix of temporal components (K x T)

     b:     np.ndarray
                spatial background (d x nb); a 1-d vector of length d is
                treated as a single background component

     f:     np.ndarray
                temporal background (nb x T); a 1-d vector of length T is
                treated as a single background component

     d1,d2: np.ndarray
                frame dimensions

     YrA:   np.ndarray
                 ROI filtered residual as it is given from update_temporal_components
                 If not given, then it is computed (K x T)

     img:   np.ndarray
                background image for contour plotting. Default is the image of all spatial components (d1 x d2)

    """

    pl.ion()
    if 'csc_matrix' not in str(type(A)):
        A = csc_matrix(A)
    # densify a sparse background; dense arrays / matrices are used as given
    if hasattr(b, 'toarray'):
        b = b.toarray()
    # single background given as 1-d vectors -> column (d x 1) / row (1 x T)
    if np.ndim(b) == 1:
        b = np.reshape(b, (-1, 1))
    if np.ndim(f) == 1:
        f = np.reshape(f, (1, -1))

    nr, T = C.shape
    nb = f.shape[0]
    last = nr + nb - 1  # index of the last (background) component
    nA2 = np.sqrt(np.array(A.power(2).sum(axis=0))).squeeze()

    if YrA is None:
        Y_r = spdiags(old_div(1, nA2), 0, nr, nr) * (A.T.dot(Yr) -
                                                     (A.T.dot(b)).dot(f) - (A.T.dot(A)).dot(C)) + C
    else:
        Y_r = YrA + C

    if img is None:
        img = np.reshape(np.array(A.mean(axis=1)), (d1, d2), order='F')

    fig = pl.figure(figsize=(10, 10))

    axcomp = pl.axes([0.05, 0.05, 0.9, 0.03])

    ax1 = pl.axes([0.05, 0.55, 0.4, 0.4])
    ax3 = pl.axes([0.55, 0.55, 0.4, 0.4])
    ax2 = pl.axes([0.05, 0.1, 0.9, 0.4])

    s_comp = Slider(axcomp, 'Component', 0, last, valinit=0)
    vmax = np.percentile(img, 95)

    def update(val):
        # np.int was removed in NumPy 1.24; the builtin int is equivalent
        i = int(np.round(s_comp.val))
        print('Component:' + str(i))

        if i < nr:

            ax1.cla()
            imgtmp = np.reshape(A[:, i].toarray(), (d1, d2), order='F')
            ax1.imshow(imgtmp, interpolation='None', cmap=pl.cm.gray, vmax=np.max(imgtmp)*0.5)
            ax1.set_title('Spatial component ' + str(i + 1))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), Y_r[i], 'c', linewidth=3)
            ax2.plot(np.arange(T), C[i], 'r', linewidth=2)
            ax2.set_title('Temporal component ' + str(i + 1))
            ax2.legend(labels=['Filtered raw data', 'Inferred trace'])

            ax3.cla()
            ax3.imshow(img, interpolation='None', cmap=pl.cm.gray, vmax=vmax)
            # overlay only the nonzero pixels of the component
            imgtmp2 = imgtmp.copy()
            imgtmp2[imgtmp2 == 0] = np.nan
            ax3.imshow(imgtmp2, interpolation='None',
                       alpha=0.5, cmap=pl.cm.hot)
            ax3.axis('off')
        else:
            ax1.cla()
            bkgrnd = np.reshape(b[:, i - nr], (d1, d2), order='F')
            ax1.imshow(bkgrnd, interpolation='None')
            ax1.set_title('Spatial background ' + str(i + 1 - nr))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), np.squeeze(np.array(f[i - nr, :])))
            ax2.set_title('Temporal background ' + str(i + 1 - nr))

    def arrow_key_image_control(event):
        # step one component per key press, clamped to the slider range
        if event.key == 'left':
            new_val = np.round(s_comp.val - 1)
            if new_val < 0:
                new_val = 0
            s_comp.set_val(new_val)

        elif event.key == 'right':
            new_val = np.round(s_comp.val + 1)
            if new_val > last:
                new_val = last
            s_comp.set_val(new_val)

    s_comp.on_changed(update)
    s_comp.set_val(0)
    fig.canvas.mpl_connect('key_release_event', arrow_key_image_control)
    pl.show()
Exemple #34
0
		RmaxV = Slider(Rmax, 'Rmax', 1, 254, valinit=rgbinit[0])
		RminV = Slider(Rmin, 'Rmin', 1, 254, valinit=rgbinit[1])
		GmaxV = Slider(Gmax, 'Gmax', 1, 254, valinit=rgbinit[2])
		GminV = Slider(Gmin, 'Gmin', 1, 254, valinit=rgbinit[3])
		BmaxV = Slider(Bmax, 'Bmax', 1, 254, valinit=rgbinit[4])
		BminV = Slider(Bmin, 'Bmin', 1, 254, valinit=rgbinit[5])

		RmaxV.on_changed(sliceupdateRmax)
		RminV.on_changed(sliceupdateRmin)
		GmaxV.on_changed(sliceupdateGmax)
		GminV.on_changed(sliceupdateGmin)
		BmaxV.on_changed(sliceupdateBmax)
		BminV.on_changed(sliceupdateBmin)

		ff=1
	else:
		RmaxV.set_val(rgbinit[0])
		RminV.set_val(rgbinit[1])
		GmaxV.set_val(rgbinit[2])
		GminV.set_val(rgbinit[3])
		BmaxV.set_val(rgbinit[4])
		BminV.set_val(rgbinit[5])
		file = open("rgb.txt", "w")
		file.write(str(rgbinit))
		file.close()
		print rgbinit

	plt.pause(0.001)
	plt.show(block=False)
	#print samp.val,sfreq.val
def view_patches_bar(Yr, A, C, b, f, d1, d2, YrA=None, secs=1, img=None):
    """view spatial and temporal components interactively

     A slider (or the left/right arrow keys) steps through the nr
     components; position nr shows the (single) background.

     Parameters
     -----------
     Yr:    np.ndarray
            movie in format pixels (d) x frames (T)
     A:     sparse matrix
                matrix of spatial components (d x K)
     C:     np.ndarray
                matrix of temporal components (K x T)
     b:     np.ndarray
                spatial background (vector of length d)

     f:     np.ndarray
                temporal background (vector of length T)
     d1,d2: np.ndarray
                frame dimensions
     YrA:   np.ndarray
                 ROI filtered residual as it is given from update_temporal_components
                 If not given, then it is computed (K x T)
     secs:  unused; kept for backward compatibility

     img:   np.ndarray
                background image for contour plotting. Default is the image of all spatial components (d1 x d2)

    """

    plt.ion()
    nr, T = C.shape
    # column norms of A (A is expected to be sparse: its .data is squared in place)
    A2 = A.copy()
    A2.data **= 2
    nA2 = np.sqrt(np.array(A2.sum(axis=0))).squeeze()
    b = np.squeeze(b)
    f = np.squeeze(f)
    if YrA is None:
        Y_r = np.array(A.T * np.matrix(Yr) - (A.T * np.matrix(b[:, np.newaxis])) * np.matrix(
            f[np.newaxis]) - (A.T.dot(A)) * np.matrix(C) + C)
    else:
        Y_r = YrA + C

    # normalise each spatial component to unit norm before display
    A = A * spdiags(1 / nA2, 0, nr, nr)
    A = A.todense()
    imgs = np.reshape(np.array(A), (d1, d2, nr), order='F')
    if img is None:
        # NOTE(review): the last component is excluded from the default image
        img = np.mean(imgs[:, :, :-1], axis=-1)

    bkgrnd = np.reshape(b, (d1, d2), order='F')
    fig = plt.figure(figsize=(10, 10))

    axcomp = plt.axes([0.05, 0.05, 0.9, 0.03])

    ax1 = plt.axes([0.05, 0.55, 0.4, 0.4])
    ax3 = plt.axes([0.55, 0.55, 0.4, 0.4])
    ax2 = plt.axes([0.05, 0.1, 0.9, 0.4])

    s_comp = Slider(axcomp, 'Component', 0, nr, valinit=0)
    vmax = np.percentile(img, 98)

    def update(val):
        # np.int was removed in NumPy 1.24; the builtin int is equivalent
        i = int(np.round(s_comp.val))
        # function-call form prints the same text on Python 2 and 3
        print('Component:' + str(i))

        if i < nr:

            ax1.cla()
            imgtmp = imgs[:, :, i]
            ax1.imshow(imgtmp, interpolation='None', cmap=plt.cm.gray)
            ax1.set_title('Spatial component ' + str(i + 1))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), np.squeeze(np.array(Y_r[i, :])), 'c', linewidth=3)
            ax2.plot(np.arange(T), np.squeeze(np.array(C[i, :])), 'r', linewidth=2)
            ax2.set_title('Temporal component ' + str(i + 1))
            ax2.legend(labels=['Filtered raw data', 'Inferred trace'])

            ax3.cla()
            ax3.imshow(img, interpolation='None', cmap=plt.cm.gray, vmax=vmax)
            # overlay only the nonzero pixels of the component
            imgtmp2 = imgtmp.copy()
            imgtmp2[imgtmp2 == 0] = np.nan
            ax3.imshow(imgtmp2, interpolation='None', alpha=0.5, cmap=plt.cm.hot)
            ax3.axis('off')
        else:

            ax1.cla()
            ax1.imshow(bkgrnd, interpolation='None')
            ax1.set_title('Spatial background')

            ax2.cla()
            ax2.plot(np.arange(T), np.squeeze(np.array(f)))
            ax2.set_title('Temporal background')

    def arrow_key_image_control(event):
        # step one component per key press, clamped to the slider range
        if event.key == 'left':
            new_val = np.round(s_comp.val - 1)
            if new_val < 0:
                new_val = 0
            s_comp.set_val(new_val)

        elif event.key == 'right':
            new_val = np.round(s_comp.val + 1)
            if new_val > nr:
                new_val = nr
            s_comp.set_val(new_val)

    s_comp.on_changed(update)
    s_comp.set_val(0)
    fig.canvas.mpl_connect('key_release_event', arrow_key_image_control)
    plt.show()
Exemple #36
0
class GridCircle:
    """
    Provided a 2D grid, this object:
    > plots the grid with a circle whose centre and radius are adjustable,
    > plots values of the grid along the circle alongside the grid.

    The radius is changed either by clicking on the grid axes (the radius is
    set to the distance between the circle centre and the clicked point) or,
    when `show_slider` is True, with the radius slider below the value plot.

    NOTE: relies on names defined elsewhere in this module (`cmap`, `Grid`,
    `Slider`, `make_axes_locatable`, `colors`, `cmx`, `mpl`, `plt`).
    """
    def __init__(self,
                 grid,
                 extent=(-1, 1, -1, 1),
                 circle_centre=(0, 0),
                 min=None,
                 max=None,
                 points_theta=100,
                 linear_interpolation=False,
                 show_slider=True):
        """
        Parameters
        ----------
        grid : 2D array-like
            Grid to plot values from.
        extent : scalars (left, right, bottom, top)
            Values of space variables at corners. (default: (-1, 1, -1, 1))
        circle_centre : scalars (x, y)
            Location of the centre of the circle to draw on top of grid.
        min : float
            Minimum value for the colormap. (default: None)
            NOTE: None will be considered as the minimum being the minimum
            value of grid.
        max : float
            Maximum value for the colormap. (default: None)
            NOTE: None will be considered as the maximum being the maximum
            value of grid.
        points_theta : int
            Number of points to consider in the interval [0, 2\\pi] when
            computing values along circle.
        linear_interpolation : bool
            Get value by linear interpolation of neighbouring grid boxes.
            (default: False)
        show_slider : bool
            Display circle radius slider.
        """

        self.circle_centre = np.array(circle_centre)
        self.radius = 0  # radius of the circle
        self.points_theta = points_theta
        # angles at which the grid is sampled along the circle
        self.theta = np.linspace(0, 2 * np.pi, points_theta)
        self.linear_interpolation = linear_interpolation

        self.show_slider = show_slider

        # figure with the grid axes (left) and the values-along-circle axes (right)
        self.fig, (self.ax_grid, self.ax_plot) = plt.subplots(1, 2)

        # COLORMAP

        # `min` and `max` shadow the builtins in this method; they are kept
        # as parameter names for backward compatibility.
        self.min = np.min(grid) if min is None else min
        self.max = np.max(grid) if max is None else max

        self.norm = colors.Normalize(vmin=self.min, vmax=self.max)  # normalises data
        self.scalarmap = cmx.ScalarMappable(
            norm=self.norm, cmap=cmap)  # scalar map for grid values

        # PLOT

        # y-axis spans the colormap range, x-axis is the angle on the circle
        self.ax_plot.set_ylim([self.min, self.max])
        self.ax_plot.set_xlim([0, 2 * np.pi])

        # plot of values along circle (filled in by draw())
        self.line, = self.ax_plot.plot(
            np.linspace(0, 2 * np.pi, self.points_theta),
            [0] * self.points_theta)

        # SLIDER

        self.extent = np.array(extent)  # grid extent

        if self.show_slider:
            self.slider_ax = make_axes_locatable(self.ax_plot).append_axes(
                'bottom', size='5%', pad=0.5)  # slider axes
            # radius ranges up to the smallest absolute extent bound
            self.slider = Slider(self.slider_ax,
                                 'radius',
                                 0,
                                 np.min(np.abs(self.extent)),
                                 valinit=self.radius)
            self.slider.on_changed(self.update_slider)

        # GRID

        self.grid_plot = self.ax_grid.imshow(grid,
                                             cmap=cmap,
                                             norm=self.norm,
                                             extent=self.extent)  # grid plot

        self.colormap_ax = make_axes_locatable(self.ax_grid).append_axes(
            'right', size='5%', pad=0.05)  # color map axes
        self.colormap = mpl.colorbar.ColorbarBase(
            self.colormap_ax,
            cmap=cmap,
            norm=self.norm,
            orientation='vertical')  # color map

        self.circle = plt.Circle(self.circle_centre,
                                 self.radius,
                                 color='black',
                                 fill=False)  # circle on grid
        self.ax_grid.add_artist(self.circle)

        # clicking on the grid sets the circle radius
        self.ax_grid.figure.canvas.mpl_connect('button_press_event',
                                               self.update_grid)

        self.update_grid_plot(grid)  # plots grid and updates circle and plot

    def get_fig_ax_cmap(self):
        """
        Returns
        -------
        fig : matplotlib.pyplot.figure object
            Figure.
        (ax_grid, ax_plot) : matplotlib.axes.Axes tuple
            Grid and plot axes.
        colormap : matplotlib.colorbar.ColorbarBase object
            Color map.
        """

        return self.fig, (self.ax_grid, self.ax_plot), self.colormap

    def update_grid_plot(self, grid, extent=None):
        """
        Plots grid.

        Parameters
        ----------
        grid : 2D array-like
            Grid to plot values from.
        extent : scalars (left, right, bottom, top)
            Values of space variables at corners. (default: None)
            NOTE: None will be considered as extent to be self.extent.
        """

        # `is not None` (not `!=`) so that array-valued extents work
        if extent is not None:
            self.extent = np.array(extent)
        self.grid = Grid(grid, extent=self.extent)

        self.grid_plot.set_data(self.grid.grid)  # plots grid
        self.grid_plot.set_extent(self.extent)  # set extent

        self.draw()  # updates circle and plot

    def update_grid(self, event):
        """
        Executes on click on figure.

        Updates radius of circle on figure.
        """

        if event.inaxes != self.circle.axes:
            return  # click was not in the grid axes

        # radius set to distance between centre of circle and clicked point
        self.radius = np.sqrt(
            np.sum(
                (np.array([event.xdata, event.ydata]) - self.circle_centre)**2))

        if self.show_slider:
            # keep the slider in sync (only exists when show_slider is True)
            self.slider.set_val(self.radius)

        self.draw()  # updates figure

    def update_slider(self, event):
        """
        Executes on slider change.

        Updates radius of circle on figure.
        """

        self.radius = self.slider.val  # radius set to slider value

        self.draw()  # updates figure

    def draw(self):
        """
        Updates figure: recomputes grid values along the circle, resizes the
        circle, and redraws the canvas.
        """

        # values of the grid along the circle
        self.line.set_ydata([
            self.grid.get_value_polar(
                self.radius,
                angle,
                centre=self.circle_centre,
                linear_interpolation=self.linear_interpolation)
            for angle in self.theta])

        self.circle.set_radius(self.radius)  # adjusting circle radius

        # both axes belong to the same figure, so one redraw updates both
        self.fig.canvas.draw()
Exemple #37
0
class BasicDendrogramViewer(object):
    """
    Interactive matplotlib viewer linking a dendrogram plot to its data.

    The left panel shows the data (3-d data is shown one slice at a time,
    chosen with a slider) with vmin/vmax sliders; the right panel shows the
    dendrogram. Clicking a structure in either panel selects it (one
    selection per mouse button), highlighting it in the dendrogram and
    outlining it with a contour on the image.

    NOTE: relies on `SelectionHub`, `DendrogramPlotter` and `warnings`
    being available at module level.
    """

    def __init__(self, dendrogram):
        """
        Parameters
        ----------
        dendrogram : Dendrogram
            Dendrogram to display; `dendrogram.data` must be 2- or
            3-dimensional.

        Raises
        ------
        ValueError
            If the data is not 2- or 3-dimensional.
        """

        if dendrogram.data.ndim not in [2, 3]:
            raise ValueError(
                "Only 2- and 3-dimensional arrays are supported at this time")

        self.hub = SelectionHub()
        self._connect_to_hub()

        self.array = dendrogram.data
        self.dendrogram = dendrogram
        self.plotter = DendrogramPlotter(dendrogram)
        self.plotter.sort(reverse=True)

        # Get the lines as individual elements, and the mapping from line to structure
        self.lines = self.plotter.get_lines(edgecolor='k')

        # Define the currently selected subtree.
        # The keys in these dictionaries are event button IDs.
        self.selected_lines = {}
        self.selected_contour = {}

        # Initiate plot
        import matplotlib.pyplot as plt
        self.fig = plt.figure(figsize=(14, 8))

        ax_image_limits = [0.1, 0.1, 0.4, 0.7]

        try:
            from wcsaxes import WCSAxes
            wcsaxes_imported = True
        except ImportError:
            wcsaxes_imported = False
            if self.dendrogram.wcs is not None:
                warnings.warn("`WCSAxes` package required for wcs coordinate display.")

        if self.dendrogram.wcs is not None and wcsaxes_imported:

            if self.array.ndim == 2:
                slices = ('x', 'y')
            else:
                slices = ('x', 'y', 1)

            ax_image = WCSAxes(self.fig, ax_image_limits, wcs=self.dendrogram.wcs, slices=slices)
            self.ax_image = self.fig.add_axes(ax_image)

        else:
            self.ax_image = self.fig.add_axes(ax_image_limits)

        from matplotlib.widgets import Slider

        # Initial colour limits, computed from the finite values only
        finite = self.array[np.isfinite(self.array)]
        self._clim = (np.min(finite), np.max(finite))

        if self.array.ndim == 2:

            self.slice = None
            self.image = self.ax_image.imshow(self.array, origin='lower', interpolation='nearest', vmin=self._clim[0], vmax=self._clim[1], cmap=plt.cm.gray)

            self.slice_slider = None

        else:

            if self.array.shape[0] > 1:

                self.slice = int(round(self.array.shape[0] / 2.))

                self.slice_slider_ax = self.fig.add_axes([0.1, 0.95, 0.4, 0.03])
                self.slice_slider_ax.set_xticklabels("")
                self.slice_slider_ax.set_yticklabels("")
                # Valid slice indices run from 0 to shape[0] - 1 inclusive
                self.slice_slider = Slider(self.slice_slider_ax, "3-d slice", 0, self.array.shape[0] - 1, valinit=self.slice, valfmt="%i")
                self.slice_slider.on_changed(self.update_slice)
                self.slice_slider.drawon = False

            else:

                self.slice = 0
                self.slice_slider = None

            self.image = self.ax_image.imshow(self.array[self.slice, :, :], origin='lower', interpolation='nearest', vmin=self._clim[0], vmax=self._clim[1], cmap=plt.cm.gray)

        self.vmin_slider_ax = self.fig.add_axes([0.1, 0.90, 0.4, 0.03])
        self.vmin_slider_ax.set_xticklabels("")
        self.vmin_slider_ax.set_yticklabels("")
        self.vmin_slider = Slider(self.vmin_slider_ax, "vmin", self._clim[0], self._clim[1], valinit=self._clim[0])
        self.vmin_slider.on_changed(self.update_vmin)
        self.vmin_slider.drawon = False

        self.vmax_slider_ax = self.fig.add_axes([0.1, 0.85, 0.4, 0.03])
        self.vmax_slider_ax.set_xticklabels("")
        self.vmax_slider_ax.set_yticklabels("")
        self.vmax_slider = Slider(self.vmax_slider_ax, "vmax", self._clim[0], self._clim[1], valinit=self._clim[1])
        self.vmax_slider.on_changed(self.update_vmax)
        self.vmax_slider.drawon = False

        self.ax_dendrogram = self.fig.add_axes([0.6, 0.3, 0.35, 0.4])
        self.ax_dendrogram.add_collection(self.lines)

        # map selection IDs -> text objects describing the current selection
        self.selected_label = {}
        for selection_id, ypos in ((1, 0.85), (2, 0.8), (3, 0.75)):
            self.selected_label[selection_id] = self.fig.text(
                0.6, ypos, "No structure selected", fontsize=18,
                color=self.hub.colors[selection_id])

        # Concatenate vertices so paths with different vertex counts work
        paths = self.lines.get_paths()
        x = np.concatenate([p.vertices[:, 0] for p in paths])
        y = np.concatenate([p.vertices[:, 1] for p in paths])
        xmin, xmax = np.min(x), np.max(x)
        ymin, ymax = np.min(y), np.max(y)
        self.lines.set_picker(2.)
        self.lines.set_zorder(0)
        dx = xmax - xmin
        self.ax_dendrogram.set_xlim(xmin - dx * 0.1, xmax + dx * 0.1)
        self.ax_dendrogram.set_ylim(ymin * 0.5, ymax * 2.0)
        self.ax_dendrogram.set_yscale('log')

        self.fig.canvas.mpl_connect('pick_event', self.line_picker)
        self.fig.canvas.mpl_connect('button_press_event', self.select_from_map)

    def show(self):
        """Show the viewer window via plt.show()."""
        import matplotlib.pyplot as plt
        plt.show()

    @staticmethod
    def _toolbar_active(canvas):
        """Return True if a toolbar tool (e.g. pan/zoom) is currently active.

        Canvases without a toolbar are treated as having no active tool.
        """
        toolbar = getattr(canvas, 'toolbar', None)
        return toolbar is not None and toolbar.mode != ''

    def update_slice(self, pos=None):
        """Display slice `pos` (rounded to an int) and refresh contours.

        For 2-d data `pos` is ignored and the full array is redisplayed.
        """
        if self.array.ndim == 2:
            self.image.set_array(self.array)
        else:
            self.slice = int(round(pos))
            self.image.set_array(self.array[self.slice, :, :])

        self.update_contours()

        self.fig.canvas.draw()

    def _connect_to_hub(self):
        """Register for selection-change notifications from the hub."""
        self.hub.add_callback(self._on_selection_change)

    def _on_selection_change(self, selection_id):
        """Hub callback: refresh highlighted lines and contours."""
        self._update_lines(selection_id)
        self.update_contours()
        self.fig.canvas.draw()

    def update_vmin(self, vmin):
        """Set the lower colour limit, never above the current upper one."""
        if vmin > self._clim[1]:
            self._clim = (self._clim[1], self._clim[1])
        else:
            self._clim = (vmin, self._clim[1])
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def update_vmax(self, vmax):
        """Set the upper colour limit, never below the current lower one."""
        if vmax < self._clim[0]:
            self._clim = (self._clim[0], self._clim[0])
        else:
            self._clim = (self._clim[0], vmax)
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def select_from_map(self, event):
        """Select the structure under a click on the image axes."""

        # Only do this if no tools are currently selected
        if self._toolbar_active(event.canvas):
            return
        if event.button not in self.selected_label:
            return

        if event.inaxes is self.ax_image:

            input_key = event.button

            # Find pixel co-ordinates of click
            ix = int(round(event.xdata))
            iy = int(round(event.ydata))

            if self.array.ndim == 2:
                indices = (iy, ix)
            else:
                indices = (self.slice, iy, ix)

            # Select the structure
            structure = self.dendrogram.structure_at(indices)
            self.hub.select(input_key, structure)

            # Re-draw
            event.canvas.draw()

    def line_picker(self, event):
        """Select the structure whose dendrogram line was picked."""

        # Only do this if no tools are currently selected
        if self._toolbar_active(event.canvas):
            return
        if event.mouseevent.button not in self.selected_label:
            return

        input_key = event.mouseevent.button

        # event.ind gives the indices of the paths that have been selected

        # Find levels of selected paths
        peaks = [event.artist.structures[i].get_peak(subtree=True)[1] for i in event.ind]

        # Pick the path with the highest peak (ties: numpy takes the first)
        ind = event.ind[np.argmax(peaks)]

        # Extract structure
        structure = event.artist.structures[ind]

        # If 3-d, move to the slice containing the peak
        if self.slice_slider is not None:
            peak_index = structure.get_peak(subtree=True)
            self.slice_slider.set_val(peak_index[0][0])

        # Select the structure
        self.hub.select(input_key, structure)

        # Re-draw
        event.canvas.draw()

    def _update_lines(self, selection_id):
        """Redraw the highlighted dendrogram lines and label for a selection."""
        structures = self.hub.selections[selection_id]
        select_subtree = self.hub.select_subtree[selection_id]

        structure = structures[0]

        # Remove previously selected collection
        if selection_id in self.selected_lines:
            self.selected_lines.pop(selection_id).remove()

        if structure is None:
            self.selected_label[selection_id].set_text("No structure selected")
            self.remove_contour(selection_id)
            self.fig.canvas.draw()
            return

        self.remove_all_contours()

        # `s` (not `structure`) so the comprehension cannot rebind `structure`
        # (list comprehension variables leak on Python 2)
        if len(structures) <= 1:
            label_text = "Selected structure: {0}".format(structure.idx)
        elif len(structures) <= 3:
            label_text = "Selected structures: {0}".format(', '.join([str(s.idx) for s in structures]))
        else:
            label_text = "Selected structures: {0}...".format(', '.join([str(s.idx) for s in structures[:3]]))

        self.selected_label[selection_id].set_text(label_text)

        # Get collection for this substructure
        lines = self.plotter.get_lines(structures=structures, subtree=select_subtree)
        lines.set_color(self.hub.colors[selection_id])
        lines.set_linewidth(2)
        lines.set_zorder(structure.height)
        self.selected_lines[selection_id] = lines

        # Add to axes
        self.ax_dendrogram.add_collection(lines)

    def remove_contour(self, selection_id):
        """Remove the contour drawn for `selection_id`, if any."""
        contour_set = self.selected_contour.pop(selection_id, None)
        if contour_set is not None:
            for collection in contour_set.collections:
                collection.remove()

    def remove_all_contours(self):
        """ Remove all selected contours. """
        # Iterate over a copy: remove_contour deletes entries from the dict
        for key in list(self.selected_contour):
            self.remove_contour(key)

    def update_contours(self):
        """Redraw one contour per active selection on the current slice."""
        from functools import reduce

        self.remove_all_contours()

        for selection_id in self.hub.selections.keys():
            structures = self.hub.selections[selection_id]
            select_subtree = self.hub.select_subtree[selection_id]

            struct = structures[0]
            if struct is None:
                continue

            if select_subtree:
                mask = struct.get_mask(subtree=True)
            else:
                # combine (sum) the masks of all selected structures
                mask = reduce(np.add, [s.get_mask(subtree=True) for s in structures])
            if self.array.ndim == 3:
                mask = mask[self.slice, :, :]
            self.selected_contour[selection_id] = self.ax_image.contour(
                mask, colors=self.hub.colors[selection_id],
                linewidths=2, levels=[0.5], alpha=0.75, zorder=struct.height)
Exemple #38
0
class Visualizer:
    """Interactive viewer showing one k-level of a 3D field at a time."""

    def __init__(self, field, fieldname, halospec=None):
        """Initializes a visualization instance, that is a window with a field.

        field is a 3D numpy array, indexed as field[i, j, k]
        fieldname is a string with the name of the field
        halospec is a 2x2 array with the definition of the halo size
            (default [[3, 3], [3, 3]]): halospec[0] holds the halo widths at
            the start/end of the i axis, halospec[1] those of the j axis

        After this call the window is shown (plt.show(block=False)).
        """
        self.field = field
        self.fieldname = fieldname

        # Register halo information
        if halospec is None:
            halospec = [[3, 3], [3, 3]]
        self.istart = halospec[0][0]
        self.iend = field.shape[0] - halospec[0][1]
        self.jstart = halospec[1][0]
        self.jend = field.shape[1] - halospec[1][1]
        self.plotHalo = True
        self.plotLogLog = False

        self.curklevel = 0

        self.figure = plt.figure()

        # Slider (`facecolor` replaces the `axisbg` keyword removed from matplotlib)
        slideraxes = plt.axes([0.15, 0.02, 0.5, 0.03], facecolor="lightgoldenrodyellow")
        self.slider = Slider(slideraxes, "K level", 0, field.shape[2] - 1, valinit=0)
        self.slider.valfmt = "%2d"
        self.slider.set_val(0)
        self.slider.on_changed(self.updateSlider)

        # CheckButton
        self.cbaxes = plt.axes([0.8, -0.04, 0.12, 0.15])
        self.cbaxes.set_axis_off()
        self.cb = CheckButtons(self.cbaxes, ("Halo", "Logscale"), (self.plotHalo, self.plotLogLog))
        self.cb.on_clicked(self.updateButton)

        # Initial plot
        self.fieldaxes = self.figure.add_axes([0.1, 0.15, 0.9, 0.75])
        self.collection = plt.pcolor(self._getField(), axes=self.fieldaxes)
        self.colorbar = plt.colorbar()
        self.fieldaxes.set_xlim(right=self._getField().shape[1])
        self.fieldaxes.set_ylim(top=self._getField().shape[0])
        plt.xlabel("i")
        plt.ylabel("j")
        self.title = plt.title("%s - Level 0" % (fieldname,))

        plt.show(block=False)

    def updateSlider(self, val):
        """Slider callback: show the k-level nearest to `val`.

        Does nothing if the rounded level is already displayed.
        """
        # Compare the *rounded* level, and keep it an int so it can index
        # the field (round() returns a float on Python 2).
        level = int(round(val))
        if level == self.curklevel:
            return
        self.curklevel = level
        self.title.set_text("%s - Level %d" % (self.fieldname, self.curklevel))

        # Draw new field level
        field = self._getField()
        size = field.shape[0] * field.shape[1]
        array = field.reshape(size)
        self.collection.set_array(array)

        self.colorbar.set_clim(vmin=field.min(), vmax=field.max())
        self.collection.set_clim(vmin=field.min(), vmax=field.max())
        self.colorbar.update_normal(self.collection)
        self.figure.canvas.draw_idle()

    def updateButton(self, label):
        """CheckButtons callback: toggle halo display or log scale."""
        if label == "Halo":
            self.plotHalo = not self.plotHalo
        elif label == "Logscale":
            self.plotLogLog = not self.plotLogLog
        self.updatePlot()

    def updatePlot(self):
        """Recreate the pcolor plot for the current halo/log-scale settings."""

        # Redraw field
        self.collection.remove()
        field = self._getField()
        if self.plotLogLog:
            minvalue = field.min()
            norm = SymLogNorm(linthresh=1e-10)
            self.collection = plt.pcolor(field, axes=self.fieldaxes, norm=norm)
            self.colorbar.set_clim(vmin=minvalue, vmax=field.max())
        else:
            self.collection = plt.pcolor(field, axes=self.fieldaxes)
            self.colorbar.set_clim(vmin=field.min(), vmax=field.max())
            self.colorbar.set_norm(norm=Normalize(vmin=field.min(), vmax=field.max()))
        self.fieldaxes.set_xlim(right=field.shape[1])
        self.fieldaxes.set_ylim(top=field.shape[0])

        self.colorbar.update_normal(self.collection)
        self.figure.canvas.draw_idle()

    def _getField(self):
        """Return the current k-level, rotated by 90 degrees for display.

        The halo is cropped away unless `plotHalo` is set.
        """
        if self.plotHalo:
            return np.rot90(self.field[:, :, self.curklevel])
        else:
            return np.rot90(self.field[self.istart : self.iend, self.jstart : self.jend, self.curklevel])
def plot_biplot(x, information, rhomin, rhomax):
    """
    Build the three figures of a density "biplot".

    Parameters
    ----------
    x : array-like
        Heights at which the profile is sampled (vertical coordinate of
        the profile plot).
    information : callable or array-like
        Density profile. If callable it is evaluated at `x`; otherwise it
        is used directly as the sampled profile.
    rhomin, rhomax : float
        Density range mapped onto the colormap.

    Returns
    -------
    fig1, fig2, fig3 : matplotlib figures
        Colour strip of the profile, profile-vs-height plot, and a figure
        holding the time slider.

    NOTE: mutates global matplotlib rc settings (font, usetex, LaTeX
    preamble) and relies on names defined elsewhere in this module
    (`evolver`, `colorline`, `maxtime`, `time`, `Slider`, `rc`, `colors`).
    """
    # Formatting input
    if callable(information):
        width = 20
        y = information(x)
    else:
        # Strip width proportional to the number of samples, but at least
        # one column so the strip image is never empty.
        width = max(1, len(information) // 5)
        y = information

    # Colour strip: the profile repeated `width` times side by side
    array = np.repeat(y, width).reshape((len(y), width))

    # Font type for all plots.
    # NOTE(review): 'Palation' looks like a typo for 'Palatino'; left as-is
    # because changing it alters the rendered font - confirm before fixing.
    rc('font', **{'family': 'serif', 'serif': ['Palation'], 'size': 24, 'weight': 'bold'})
    rc('text', usetex=True)
    # A plain string is accepted by the old string-list validator and is
    # what recent matplotlib versions require for this rcParam.
    mpl.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

    cm = plt.get_cmap('OrRd_r')
    cNorm = colors.Normalize(vmin=rhomin, vmax=rhomax)

    # Figure 1: vertical colour strip of the profile
    fig1, ax1 = plt.subplots(figsize=(2, 10))
    ax1.imshow(array, cmap=cm, norm=cNorm)

    # Figure 2: density profile against height
    fig2, ax2 = plt.subplots(figsize=(10, 8))

    x_0, rho_0 = evolver.rho_aprox(0.5)
    ax2.plot(rho_0, x_0, linewidth=5, linestyle="-", c="grey", zorder=1, alpha=0.3)
    # Plot the sampled values `y`, so a callable `information` also works
    ax2.plot(y, x, linewidth=8, linestyle="-", c="black", zorder=1)
    ax2.plot(y, x, linewidth=5, linestyle="-", c="white", zorder=1)

    # Colour-graded profile line; the z expression simplifies algebraically
    # to rhomin + rhomax - y.
    colorline(ax2, y, x, z=rhomin*(rhomin-y)/(rhomin-rhomax) + rhomax*(rhomax-y)/(rhomax-rhomin), cmap=cm, norm=cNorm, linewidth=5, alpha=1)

    # Aesthetic tuning of plot
    ax2.set_ylim([min(x), max(x)])
    ax2.set_xlim([990, 1093])
    ax2.set_xlabel(r"\textbf{Density / g L}$\boldsymbol{^{-1}}$")
    ax2.set_ylabel(r"\textbf{Height / mm}", rotation=270, labelpad=10)

    ax2.yaxis.set_label_position("right")
    ax2.tick_params(labeltop=True)
    ax2.grid(True)
    plt.setp(ax2.get_xticklabels(), fontsize=18)

    # Final Configuration
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax1.get_yticklabels(), visible=False)
    plt.setp(ax2.get_yticklabels(), visible=False)
    ax1.get_xaxis().set_visible(False)
    for axis in ['top', 'bottom', 'left', 'right']:
        ax1.spines[axis].set_linewidth(5)

    # Figure 3: time slider, positioned at the current time (in minutes)
    fig3, axslider = plt.subplots(figsize=(10, 1))
    samp = Slider(axslider, 'Time (min)', 0., maxtime, valinit=0, color='grey', alpha=0.3, valfmt='%i'.ljust(5))
    samp.set_val(float(time)/60.)

    return fig1, fig2, fig3
class py3DSeedEditor:
    """ Viewer and seed editor for 2D and 3D data.

    py3DSeedEditor(img, ...)

    img: 2D or 3D grayscale data
    voxelsizemm: size of voxel (currently unused), default is [1, 1, 1]
    initslice: 0
    colorbar: True/False, default is True
    cmap: colormap
    seeds: optional array of seed labels, indexed like img
    contour: optional array whose current slice is drawn as a contour
    zaxis: axis with slice numbers (0 or 2)
    mouse_button_map: maps mouse button numbers to seed labels,
        default is identity for buttons 1-8
    windowW, windowC: display window width and centre; used only when both
        are truthy (intensity range windowC -/+ windowW / 2), otherwise the
        data min/max is used
    range_per_slice: if True, intensity range is computed per slice

    Mouse drags on the image paint seeds on the current slice (label given
    by mouse_button_map); the mouse wheel or the slider changes slice.

    ed = py3DSeedEditor(img)
    ed.show()
    selected_seeds = ed.seeds

    """

    def __init__(self, img, voxelsizemm=None, initslice=0, colorbar=True,
                 cmap=matplotlib.cm.Greys_r, seeds=None, contour=None, zaxis=0,
                 mouse_button_map=None,
                 windowW=None, windowC=None,
                 range_per_slice=False
                 ):
        if mouse_button_map is None:
            # Default: identity mapping for mouse buttons 1..8 (a fresh dict
            # per instance rather than a shared mutable default).
            mouse_button_map = {i: i for i in range(1, 9)}

        self.fig = plt.figure()

        if len(img.shape) == 2:
            # Promote a 2D image to a single-slice 3D volume
            imgtmp = img
            img = np.zeros([1, imgtmp.shape[0], imgtmp.shape[1]])
            img[-1, :, :] = imgtmp
            zaxis = 0

        # Rotate data so that the slice axis becomes the last axis
        img = self._rotate_start(img, zaxis)
        seeds = self._rotate_start(seeds, zaxis)
        contour = self._rotate_start(contour, zaxis)

        self.rotated_back = False
        self.zaxis = zaxis

        # if True, intensity range is calculated per slice = better
        # visualisation for higher number of labels
        self.range_per_slice = range_per_slice

        self.imgshape = list(img.shape)
        self.img = img
        self.actual_slice = initslice
        self.colorbar = colorbar
        self.cmap = cmap
        if seeds is None:
            self.seeds = np.zeros(self.imgshape, np.int8)
        else:
            self.seeds = seeds
        if not (windowW and windowC):
            self.imgmax = np.max(img)
            self.imgmin = np.min(img)
        else:
            self.imgmax = windowC + (windowW / 2)
            self.imgmin = windowC - (windowW / 2)

        # Mapping mouse button to class number. Default is normal order.
        self.button_map = mouse_button_map

        self.contour = contour

        self.press = None
        self.press2 = None

        # language (button captions)
        self.texts = {'btn_delete': 'Delete', 'btn_close': 'Close'}

        self.ax = self.fig.add_axes([0.2, 0.3, 0.7, 0.6])

        self.draw_slice()

        if self.colorbar:
            self.fig.colorbar(self.imsh)

        # user interface look (`facecolor` replaces the removed `axisbg`)
        axcolor = 'lightgoldenrodyellow'
        ax_actual_slice = self.fig.add_axes([0.2, 0.2, 0.6, 0.03], facecolor=axcolor)
        self.actual_slice_slider = Slider(ax_actual_slice, 'Slice', 0,
                                          self.imgshape[2], valinit=initslice)

        # connection to wheel events
        self.fig.canvas.mpl_connect('scroll_event', self.on_scroll)
        self.actual_slice_slider.on_changed(self.sliceslider_update)
        # seed painting
        self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.fig.canvas.mpl_connect('button_release_event', self.on_release)
        self.fig.canvas.mpl_connect('motion_notify_event', self.on_motion)

        # delete seeds button
        self.ax_delete_seeds = self.fig.add_axes([0.2, 0.1, 0.1, 0.075])
        self.btn_delete = Button(self.ax_delete_seeds, self.texts['btn_delete'])
        self.btn_delete.on_clicked(self.callback_delete)

        # close button: stored in its own attributes so the delete button
        # keeps a reference (matplotlib widgets must stay referenced to
        # remain responsive)
        self.ax_close = self.fig.add_axes([0.7, 0.1, 0.1, 0.075])
        self.btn_close = Button(self.ax_close, self.texts['btn_close'])
        self.btn_close.on_clicked(self.callback_close)

        self.draw_slice()

    def _rotate_start(self, data, zaxis):
        """Move the slice axis `zaxis` (0 or 2) to the last position."""
        if data is not None:
            if zaxis == 0:
                data = np.transpose(data, (1, 2, 0))
            elif zaxis == 2:
                pass
            else:
                print("problem with zaxis in _rotate_start()")

        return data

    def _rotate_end(self, data, zaxis):
        """Undo _rotate_start (only once; warns if already rotated back)."""
        if data is not None:
            if not self.rotated_back:
                if zaxis == 0:
                    data = np.transpose(data, (2, 0, 1))
                elif zaxis == 2:
                    pass
                else:
                    print("problem with zaxis in _rotate_start()")
            else:
                print("There is a danger in calling show() twice")

        return data

    def update_slice(self):
        """Clear the image axes and redraw the current slice."""
        # Clearing the axes is the only way found to remove the contour
        self.ax.cla()
        self.draw_slice()

    def draw_slice(self):
        """Draw the current slice, its seed overlay and optional contour."""
        self.actual_slice = int(self.actual_slice)
        sliceimg = self.img[:, :, self.actual_slice]
        if self.range_per_slice:
            vmin, vmax = sliceimg.min(), sliceimg.max()
        else:
            vmin, vmax = self.imgmin, self.imgmax
        self.imsh = self.ax.imshow(sliceimg, cmap=self.cmap, vmin=vmin, vmax=vmax, interpolation='nearest')
        self.ax.imshow(self.prepare_overlay(self.seeds[:, :, self.actual_slice]), interpolation='nearest', vmin=self.imgmin, vmax=self.imgmax)

        if self.contour is not None:
            try:
                # best effort: drawing problems for a slice (e.g. a None
                # object in the image) are deliberately ignored
                self.ax.contour(self.contour[:, :, self.actual_slice], 1, linewidths=2)
            except Exception:
                pass

        self.fig.canvas.draw()

    def next_slice(self):
        """Advance to the next slice, wrapping around to the first."""
        self.actual_slice = self.actual_slice + 1
        if self.actual_slice >= self.imgshape[2]:
            self.actual_slice = 0

    def prev_slice(self):
        """Go back to the previous slice, wrapping around to the last."""
        self.actual_slice = self.actual_slice - 1
        if self.actual_slice < 0:
            self.actual_slice = self.imgshape[2] - 1

    def sliceslider_update(self, val):
        """Slider callback: show the slice nearest to `val`."""
        # Round to a slice index and clamp it: the slider range ends at
        # imgshape[2], which is one past the last valid index.
        self.actual_slice = min(max(int(round(val)), 0), self.imgshape[2] - 1)
        self.update_slice()

    def prepare_overlay(self, seeds):
        """Return an RGBA overlay: labels 1/2/3 -> red/green/blue, seeds opaque."""
        sh = list(seeds.shape)
        if len(sh) == 2:
            sh.append(4)
        else:
            sh[2] = 4
        overlay = np.zeros(sh)

        overlay[:, :, 0] = (seeds == 1)
        overlay[:, :, 1] = (seeds == 2)
        overlay[:, :, 2] = (seeds == 3)
        overlay[:, :, 3] = (seeds > 0)

        return overlay

    def show(self):
        """ Function run viewer window.

        Returns the seeds, rotated back to the original axis order.
        """
        plt.show()
        # Rotate data back according to zaxis
        self.img = self._rotate_end(self.img, self.zaxis)
        self.seeds = self._rotate_end(self.seeds, self.zaxis)
        self.contour = self._rotate_end(self.contour, self.zaxis)
        self.rotated_back = True
        return self.seeds

    def on_scroll(self, event):
        ''' mouse wheel is used for setting slider value'''
        if event.button == 'up':
            self.next_slice()
        if event.button == 'down':
            self.prev_slice()
        self.actual_slice_slider.set_val(self.actual_slice)

    # painting -------------------
    def on_press(self, event):
        'on button press inside the image axes, start collecting a stroke'
        if event.inaxes != self.ax:
            return
        self.press = [event.xdata], [event.ydata], event.button

    def on_motion(self, event):
        'on motion inside the image axes, extend the current stroke'
        if self.press is None:
            return

        if event.inaxes != self.ax:
            return

        x0, y0, btn = self.press
        x0.append(event.xdata)
        y0.append(event.ydata)

    def on_release(self, event):
        'on release store the stroke as seeds and reset the press data'
        if self.press is None:
            return
        x0, y0, btn = self.press

        # button mapping: mouse button -> seed label
        label = self.button_map[btn]

        # image rows correspond to ydata, columns to xdata
        self.set_seeds(y0, x0, self.actual_slice, label)

        self.press = None
        self.update_slice()

    def callback_delete(self, event):
        """Clear all seeds on the current slice."""
        self.seeds[:, :, self.actual_slice] = 0
        self.update_slice()

    def callback_close(self, event):
        """Clear and close this editor's figure."""
        self.fig.clf()
        plt.close(self.fig)

    def set_seeds(self, px, py, pz, value=1, voxelsizemm=None, cursorsizemm=None):
        """Set seeds at points (px[i], py[i]) of slice pz to `value`.

        Coordinates may be floats (e.g. mouse data coordinates); they are
        rounded to the nearest pixel and clamped to the seed array bounds.
        `voxelsizemm` and `cursorsizemm` are currently unused.
        """
        assert len(px) == len(py), 'px and py describes a point, their size must be same'

        nrows, ncols = self.seeds.shape[0], self.seeds.shape[1]
        for row, col in zip(px, py):
            irow = min(max(int(round(row)), 0), nrows - 1)
            icol = min(max(int(round(col)), 0), ncols - 1)
            self.seeds[irow, icol, pz] = value

    def get_seed_sub(self, label):
        """ Return list of all seeds with specific label
        """
        sx, sy, sz = np.nonzero(self.seeds == label)

        return sx, sy, sz

    def get_seed_val(self, label):
        """ Return data values for specific seed label"""
        return self.img[self.seeds == label]
Exemple #41
0
class Parameters(object):
    """
    An interactive matplotlib window to test and set parameters for eye tracking
    on a subset of frames.

    Keyboard and mouse controls are printed to stdout when the window opens.
    The current parameters live in ``self.params``. ``self.done`` becomes True
    only when the user confirms the parameters with the 'c' key.
    """

    def __init__(self, notes):
        """
        Build the parameter window and display the first frame.

        Parameters
        ----------
        notes : object
            Must provide ``wintitle`` (window title), ``stack`` (cropped frame
            stack, indexed as ``stack[frame, :, :]``), ``axis`` (passed to
            ``geom.parse_axis``; its 4th return value is used as the mm-per-pixel
            factor) and ``diffs`` (per-frame differences from the reference frame).
        """
        self.wintitle = notes.wintitle
        self.stack = notes.stack  # cropped frame stack
        self.frame_dims = notes.stack.shape[-2:]  # cropped frame dimensions
        self.nframes = notes.stack.shape[0]  # number of frames
        _, _, _, mmperpx = geom.parse_axis(notes.axis)  # mm to pixel conversion factor
        self.diffs = notes.diffs  # mean pixel-wise differences from reference frame

        # status variable, set to True on 'c' keypress to exit parameter-setting
        # loop and perform eye tracking on all files:
        self.done = False

        # define a parameter dictionary with some initial values
        # first specify maximum/ initial values for some of the parameters
        method = 'convolve'  # eye tracking method ('threshold' or 'convolve')
        shape = 'lse_ellipse'  # shape fitting method ('lse_ellipse' or 'min_enclosing')
        # parameters for image pre-processing:
        c_sig_maxval = 255 / 2  # intensity space sigma for bilateral filter
        s_sig_maxval = max(self.frame_dims) / 8  # spatial sigma for bilateral filter
        k_size_maxval = max(self.frame_dims) / 8  # kernel size for median filter
        k_size_inival = self._odd(k_size_maxval / 2)  # must be an odd integer
        # parameter for threshold method:
        area_thr_maxval = np.pi * (0.5 / mmperpx)**2  # area threshold
        # parameters for convolve method:
        conv_size_maxval = 0.5 / mmperpx  # kernel size for convolution
        conv_size_inival = self._odd(conv_size_maxval / 4)  # must be an odd integer
        rad_maxval = np.round(2 / mmperpx).astype('int')  # max radius for edges
        # parameter for blink detection
        b_thr_inival = np.median(self.diffs) * 3
        # build parameter dictionary
        self.params = {'eq': False, 'eq_sp': 0, 'eq_rp': 0, 'k_size': k_size_inival,
                       'c_sig': c_sig_maxval / 2, 's_sig': s_sig_maxval / 2,
                       'dark_thr': 20, 'light_thr': 235, 'area_thr': 0,
                       'conv_size': conv_size_inival, 'max_rad': rad_maxval / 2,
                       'shape': shape, 'ht_fit': False, 'method': method,
                       'blink_thr': b_thr_inival}

        # set initial frames to display
        self.frame_ind = 0  # current frame index
        self.frame = self.stack[self.frame_ind, :, :]  # original frame
        # processed (filtered and equalized) frame:
        self.p_frame = img.pre_process(self.frame, self.params)
        # inverted binarized frame (dark contours):
        self.b_frame, _ = img.binarize(self.p_frame, self.params)
        # p_frame convolved with a black square (center = argmin[c_frame]);
        # uses the convolution kernel size, exactly like update_pframes
        self.c_frame, self.center = img.square_convolve(self.p_frame, self.params['conv_size'])
        self.g_frame = img.gradient(self.p_frame)  # gradient magnitude of p_frame
        self.edge_pts = geom.starburst(self.g_frame, self.center, self.params['max_rad'], 100)

        # set up figure with plots and images. plt.figure() rather than
        # plt.subplots(): the latter adds a full-figure axes that newer
        # matplotlib keeps underneath the overlapping gridspec subplots.
        self.fig = plt.figure()
        self.fig.canvas.manager.set_window_title(self.wintitle)
        grid = gs.GridSpec(12, 12)
        # original frame display axis
        self.ax_frame = self._image_axis(grid[:6, :4])
        # processed frame display axis
        self.ax_pframe = self._image_axis(grid[6:9, :2], 'Processed Frame')
        # dark binary display axis
        self.ax_bframe = self._image_axis(grid[6:9, 2:4], 'Dark Contours')
        # convolution display axis
        self.ax_cframe = self._image_axis(grid[9:12, :2], 'Convolution')
        self.c_dot = self.ax_cframe.scatter(self.center[0], self.center[1], s=8)
        # gradient display axis
        self.ax_gframe = self._image_axis(grid[9:12, 2:4], 'Gradient')
        self.e_dots = self.ax_gframe.scatter(self.edge_pts[:, 0], self.edge_pts[:, 1],
                                             s=1, color='C0')

        # intensity histogram
        self.ax_hist = plt.subplot(grid[:4, 4:7])
        self.hist = self.ax_hist.hist(self.p_frame.ravel(), bins=np.arange(256),
                                      color='C0', alpha=0.5, density=True)
        self.dark_thr = self.ax_hist.axvline(self.params['dark_thr'], color=[0, 0, 0],
                                             marker=' ', label='Dark Threshold')
        self.light_thr = self.ax_hist.axvline(self.params['light_thr'], color=[.75, .75, .75],
                                              marker=' ', label='Light Threshold')
        self.ax_hist.set_title('Intensity Histogram')
        self.ax_hist.legend()
        # timecourse of mean differences between each frame and the reference;
        # x positions are the frame indices so they line up with f_dot
        self.ax_diffs = plt.subplot(grid[4:8, 4:7])
        self.ax_diffs.scatter(np.arange(len(self.diffs)), self.diffs, s=3, color='C0')
        self.ax_diffs.set_xlim([0, self.nframes])
        self.f_dot = self.ax_diffs.scatter(0, self.diffs[0], s=4, color='C1')
        self.blink_thr1 = self.ax_diffs.axhline(self.params['blink_thr'], marker=' ',
                                                color=[0, 0, 0], label='Blink Threshold')
        self.ax_diffs.legend()
        self.ax_diffs.set_ylabel('Diff. from reference')
        # slider for frame scrolling
        self.ax_fslider = plt.subplot(grid[8, 4:7])
        self.f_slider = Slider(self.ax_fslider, 'Frame', 0, self.nframes - 1,
                               valinit=0, valfmt='%d')
        # distribution of the differences
        self.ax_dist = plt.subplot(grid[9:12, 4:7])
        self.ax_dist.hist(self.diffs.ravel(), bins=100, color='C0', density=True)
        self.blink_thr2 = self.ax_dist.axvline(self.params['blink_thr'], marker=' ',
                                               color=[0, 0, 0])
        self.ax_dist.set_xlabel('Diff. from reference')

        # slider for eq_sp equalization parameter
        self.ax_spslider = plt.subplot(grid[0, 8:11])
        self.sp_slider = Slider(self.ax_spslider, 'EQ Sep. Point', 0, 1,
                                valinit=self.params['eq_sp'], valfmt='%02f')
        # slider for eq_rp equalization parameter
        self.ax_rpslider = plt.subplot(grid[1, 8:11])
        self.rp_slider = Slider(self.ax_rpslider, 'EQ Range Point', 0, 1,
                                valinit=self.params['eq_rp'], valfmt='%02f')
        # slider for c_sig filter parameter
        self.ax_cslider = plt.subplot(grid[3, 8:11])
        self.c_slider = Slider(self.ax_cslider, 'Color Sigma', 0, c_sig_maxval,
                               valinit=self.params['c_sig'], valfmt='%d')
        # slider for s_sig filter parameter
        self.ax_sslider = plt.subplot(grid[4, 8:11])
        self.s_slider = Slider(self.ax_sslider, 'Spatial Sigma', 0, s_sig_maxval,
                               valinit=self.params['s_sig'], valfmt='%d')
        # slider for k_size filter parameter
        self.ax_kslider = plt.subplot(grid[5, 8:11])
        self.k_slider = Slider(self.ax_kslider, 'Kernel Size', 1, k_size_maxval,
                               valinit=k_size_inival, valfmt='%d')
        # slider for area_thr parameter
        self.ax_aslider = plt.subplot(grid[7, 8:11])
        self.a_slider = Slider(self.ax_aslider, 'Area Thresh', 0, area_thr_maxval,
                               valinit=0, valfmt='%d')
        # slider for conv_size convolution parameter
        self.ax_ckslider = plt.subplot(grid[9, 8:11])
        self.ck_slider = Slider(self.ax_ckslider, 'Conv. Kernel Size', 0, conv_size_maxval,
                                valinit=self.params['conv_size'], valfmt='%d')
        # slider for max_rad starburst parameter
        self.ax_rslider = plt.subplot(grid[10, 8:11])
        self.r_slider = Slider(self.ax_rslider, 'Max. Radius', 0, rad_maxval,
                               valinit=self.params['max_rad'], valfmt='%d')

        # disconnect default matplotlib key bindings
        manager, canvas = self.fig.canvas.manager, self.fig.canvas
        canvas.mpl_disconnect(manager.key_press_handler_id)
        # maximize display window
        manager.window.showMaximized()

        # connect callback functions
        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.cidkey = self.fig.canvas.mpl_connect('key_press_event', self.on_key)
        self.f_slider.on_changed(self.update_frame)
        self.sp_slider.on_changed(self.sp_update)
        self.rp_slider.on_changed(self.rp_update)
        self.c_slider.on_changed(self.c_update)
        self.s_slider.on_changed(self.s_update)
        self.k_slider.on_changed(self.k_update)
        self.a_slider.on_changed(self.a_update)
        self.ck_slider.on_changed(self.ck_update)
        self.r_slider.on_changed(self.r_update)

        # get pupil and reflection patches
        self.get_eye_patches()
        # display
        self.update_display()
        # display user prompts
        print("Set Parameters")
        print("    - click on the slider or use the arrow keys to scroll between video frames")
        print("")
        print("    Pre-Processing:")
        print("    - use 'e' to toggle histogram equalization")
        print("    - set parameters for equalization and smoothing using the sliders")
        print("")
        print("    - use 'm' to change eye tracking method")
        print("    Threshold Method:")
        print("    - set dark and light binarization thresholds by clicking on the")
        print("    intensity histogram (left & right clicks respectively)")
        print("    - set an area threshold using the slider")
        print("    Convolution Method:")
        print("    - set a kernel size and a maximum radius using the sliders")
        print("")
        print("    Ellipse Fitting:")
        print("    - use 'f' to toggle the shape fit (lse_ellipse / min_enclosing)")
        print("    - use 'h' to toggle ellipse fitting via re-sampling (Hough Transform)")
        print("")
        print("    - set the blink-detection threshold by clicking on the difference")
        print("    scatter plot or distribution plot")
        print("")
        print("    - use 'r' to re-set parameters and return to the previous window,")
        print("    'c' to confirm the current parameters and commence tracking,")
        print("    or 'esc' to quit (parameters will not be saved)")

    @staticmethod
    def _odd(value):
        """Round *value* to an int and bump even results up by one (kernels must be odd)."""
        size = int(np.round(value))
        return size + 1 if size % 2 == 0 else size

    @staticmethod
    def _image_axis(spec, title=None):
        """Create an axis without ticks/frame at gridspec cell *spec*, optionally titled."""
        ax = plt.subplot(spec)
        ax.axis('off')
        if title is not None:
            ax.set_title(title)
        return ax

    def _set_blink_thr(self, value):
        """Store *value* as the blink threshold and redraw both threshold lines."""
        self.blink_thr1.remove()
        self.blink_thr2.remove()
        self.params['blink_thr'] = value
        self.blink_thr1 = self.ax_diffs.axhline(value, marker=' ', color=[0, 0, 0],
                                                label='Blink Threshold')
        self.blink_thr2 = self.ax_dist.axvline(value, marker=' ', color=[0, 0, 0])

    def on_click(self, event):
        """Set thresholds based on a mouse click.

        On the intensity histogram, a left click sets the dark threshold and a
        right click the light threshold. The binarized frame and eye patches
        are then recomputed. A left click on the difference timecourse (y value)
        or on the difference distribution (x value) sets the blink threshold.
        The display is redrawn after every click.
        """
        # set binarization thresholds
        if event.inaxes == self.ax_hist:
            if event.button == 1:  # left click
                self.dark_thr.remove()
                self.params['dark_thr'] = event.xdata
                self.dark_thr = self.ax_hist.axvline(self.params['dark_thr'],
                                                     color=[0, 0, 0], marker=' ',
                                                     label='Dark Threshold')
            elif event.button == 3:  # right click
                self.light_thr.remove()
                self.params['light_thr'] = event.xdata
                self.light_thr = self.ax_hist.axvline(self.params['light_thr'],
                                                      color=[0.75, 0.75, 0.75], marker=' ',
                                                      label='Light Threshold')
            # update binarized frame
            self.b_frame, _ = img.binarize(self.p_frame, self.params)
            # update eye patches
            self.get_eye_patches()

        # set blink threshold in diffs timecourse axis
        elif event.inaxes == self.ax_diffs:
            if event.button == 1:
                self._set_blink_thr(event.ydata)
        # set blink threshold in diffs distribution axis
        elif event.inaxes == self.ax_dist:
            if event.button == 1:
                self._set_blink_thr(event.xdata)

        # update display
        self.update_display()

    def on_key(self, event):
        """Specify the functions of pressed keys in the matplotlib figure.

        e: toggle equalization; m: toggle method; f: toggle shape fit;
        h: toggle Hough fit; left/right: step frames; c: confirm and close;
        escape: close and exit the process; r: close without confirming.
        """
        # toggle histogram equalization
        if event.key == 'e':
            self.params['eq'] = not self.params['eq']
            print("    EQ: " + str(self.params['eq']))
            self.update_pframes()
            self.update_hist()
            self.get_eye_patches()
            self.update_display()

        # toggle eye tracking method
        elif event.key == 'm':
            if self.params['method'] == 'threshold':
                self.params['method'] = 'convolve'
                print("    Method: convolve")
            elif self.params['method'] == 'convolve':
                self.params['method'] = 'threshold'
                print("    Method: threshold")
            self.get_eye_patches()  # update eye patches
            self.update_display()

        # toggle ellipse fitting method
        elif event.key == 'f':
            if self.params['shape'] == 'lse_ellipse':
                self.params['shape'] = 'min_enclosing'
            elif self.params['shape'] == 'min_enclosing':
                self.params['shape'] = 'lse_ellipse'
            print("    Shape fit: " + self.params['shape'])
            self.get_eye_patches()
            self.update_display()

        # toggle Hough transform ellipse fitting
        elif event.key == 'h':
            self.params['ht_fit'] = not self.params['ht_fit']
            print("    Hough transform: " + str(self.params['ht_fit']))
            self.get_eye_patches()  # update eye patches
            self.update_display()

        # change frame
        elif event.key == 'right':  # move forward one frame
            if np.round(self.f_slider.val) < self.nframes - 1:
                self.f_slider.set_val(self.f_slider.val + 1)
        elif event.key == 'left':  # move back one frame
            if np.round(self.f_slider.val) > 0:
                self.f_slider.set_val(self.f_slider.val - 1)

        # continue, close, or re-set
        elif event.key == 'c':  # confirm parameters and close figure
            print("Eye tracking parameters confirmed")
            self.done = True
            plt.close('all')
        elif event.key == 'escape':  # quit eye tracking
            print("Eye tracking aborted")
            plt.close('all')
            sys.exit()
        elif event.key == 'r':  # re-set & return to annotation
            print("Parameters re-set")
            plt.close('all')

    def update_frame(self, val):
        """Update the frame and plots from the frame slider value."""
        # update frame index
        self.frame_ind = int(np.round(self.f_slider.val))

        # update displays
        self.frame = self.stack[self.frame_ind, :, :]  # raw frame
        self.update_pframes()
        self.update_hist()
        # diffs plot
        self.f_dot.remove()
        self.f_dot = self.ax_diffs.scatter(self.frame_ind, self.diffs[self.frame_ind],
                                           s=4, color=[1, 0, 1])
        self.get_eye_patches()
        self.update_display()

    def sp_update(self, val):
        """Update the eq_sp equalization parameter and re-apply filters."""
        self.params['eq_sp'] = self.sp_slider.val
        self.update_pframes()
        self.update_hist()
        self.get_eye_patches()
        self.update_display()

    def rp_update(self, val):
        """Update the eq_rp equalization parameter and re-apply filters."""
        self.params['eq_rp'] = self.rp_slider.val
        self.update_pframes()
        self.update_hist()
        self.get_eye_patches()
        self.update_display()

    def c_update(self, val):
        """Update the c_sig filter parameter and re-apply filters."""
        self.params['c_sig'] = self.c_slider.val
        self.update_pframes()
        self.get_eye_patches()
        self.update_display()

    def s_update(self, val):
        """Update the s_sig filter parameter and re-apply filters."""
        self.params['s_sig'] = self.s_slider.val
        self.update_pframes()
        self.get_eye_patches()
        self.update_display()

    def k_update(self, val):
        """Update the k_size filter parameter (forced odd) and re-apply filters."""
        self.params['k_size'] = self._odd(self.k_slider.val)
        self.update_pframes()
        self.get_eye_patches()
        self.update_display()

    def a_update(self, val):
        """Update the dark contour area threshold and re-apply eye tracking."""
        self.params['area_thr'] = self.a_slider.val
        self.get_eye_patches()
        self.update_display()

    def ck_update(self, val):
        """Update the conv_size convolution parameter (forced odd) and re-apply eye tracking."""
        self.params['conv_size'] = self._odd(self.ck_slider.val)
        self.update_pframes()
        self.get_eye_patches()
        self.update_display()

    def r_update(self, val):
        """Update max_rad and re-apply eye tracking."""
        self.params['max_rad'] = self.r_slider.val
        self.update_pframes()
        self.get_eye_patches()
        self.update_display()

    def update_pframes(self):
        """Recompute the processed, binarized, convolved and gradient frames and their markers."""
        self.p_frame = img.pre_process(self.frame, self.params)
        self.b_frame, _ = img.binarize(self.p_frame, self.params)
        self.c_frame, self.center = img.square_convolve(self.p_frame, self.params['conv_size'])
        self.c_dot.remove()
        self.c_dot = self.ax_cframe.scatter(self.center[0], self.center[1], s=8, c='C0')
        self.g_frame = img.gradient(self.p_frame)
        self.edge_pts = geom.starburst(self.g_frame, self.center, self.params['max_rad'], 100)
        self.e_dots.remove()
        self.e_dots = self.ax_gframe.scatter(self.edge_pts[:, 0], self.edge_pts[:, 1],
                                             s=1, color='C0')

    def update_hist(self):
        """Redraw the intensity histogram of the processed frame."""
        # remove old histogram patches
        for patch in self.hist[2]:
            patch.remove()
        # update histogram
        self.hist = self.ax_hist.hist(self.p_frame.ravel(), bins=np.arange(256),
                                      color='C0', alpha=0.5, density=True)
        # set ylim to match current histogram
        self.ax_hist.set_ylim(top=self.hist[0].max())

    def update_display(self):
        """Redraw all frame images and the figure canvas."""
        self.ax_frame.imshow(self.frame, cmap='gray')
        self.ax_pframe.imshow(self.p_frame, cmap='gray')
        self.ax_bframe.imshow(self.b_frame, cmap='gray')
        self.ax_cframe.imshow(self.c_frame, cmap='gray')
        self.ax_gframe.imshow(self.g_frame, cmap='gray')
        self.fig.canvas.draw()

    @staticmethod
    def _ellipse_pair(fit, color):
        """Return two identical outline Ellipse patches for one fit result.

        *fit* is read as (x, y, a, b, angle); a and b are doubled to get the
        Ellipse width/height, i.e. they are treated as half-axes. Two patches
        are built because a single patch can only be added to one axes.
        """
        xy = (fit[0], fit[1])
        width, height, theta = fit[2] * 2, fit[3] * 2, fit[4]
        return tuple(Ellipse(xy, width, height, angle=theta, lw=2, ec=color, fill=False)
                     for _ in range(2))

    def get_eye_patches(self):
        """Run eye tracking on the current frame and redraw pupil/reflection patches.

        Raises
        ------
        ValueError
            If ``self.params['method']`` is neither 'threshold' nor 'convolve'.
        """
        if self.params['method'] == 'threshold':
            pupil, cr = main.itrack_threshold(self.frame, self.params)
        elif self.params['method'] == 'convolve':
            pupil, cr = main.itrack_convolve(self.frame, self.params)
        else:
            raise ValueError("unknown eye tracking method: %r" % (self.params['method'],))

        # clear any old patches. Axes.patches is a read-only list in recent
        # matplotlib, so remove each artist instead of clearing the list.
        for ax in (self.ax_frame, self.ax_pframe):
            for patch in list(ax.patches):
                patch.remove()

        if not np.isnan(pupil).all():  # if pupil fit was successful
            self.pupil, self.pupil2 = self._ellipse_pair(pupil, 'C0')
            self.ax_frame.add_patch(self.pupil)
            self.ax_pframe.add_patch(self.pupil2)

        if not np.isnan(cr).all():  # if reflection fit was successful
            # height uses the same *2 scaling as the pupil (was *3)
            self.cr, self.cr2 = self._ellipse_pair(cr, 'C1')
            self.ax_frame.add_patch(self.cr)
            self.ax_pframe.add_patch(self.cr2)
class CompareAnimation:
    """
    Launch two parallel animations of runs, so the user can
    easily compare the structures. Example:
        R1 = ReadRun("fake/path/run_1/")
        R2 = ReadRun("fake/path/run_2/")
        new_anim=CompareAnimation(R1.S,R2.S)
        new_anim.launch()
        plt.show()

    Every snapshot needs a ``t`` attribute and ``x``/``y``/``z`` coordinate
    arrays. Snapshot lists are assumed to be ordered by increasing ``t``.
    If one run starts later than the other, a deep copy of its first snapshot,
    stamped with the earlier start time, is prepended to it. Both runs therefore
    start at the same time. The caller's lists are not modified.
    """
    def __init__(self, snaplist1, snaplist2, symbol="bo", dt=None, markersize=2, **kwargs):
        # work on copies so the caller's snapshot lists are left untouched
        snaplist1, snaplist2 = list(snaplist1), list(snaplist2)
        start1, start2 = snaplist1[0].t, snaplist2[0].t
        # pad the later-starting run so it shows its first state from the
        # common (earlier) start time onwards
        if start1 < start2:
            snaplist2.insert(0, self._padded_copy(snaplist2[0], start1))
        elif start2 < start1:
            snaplist1.insert(0, self._padded_copy(snaplist1[0], start2))
        self.snaplists = [snaplist1, snaplist2]
        self.times = [[s.t for s in snapl] for snapl in self.snaplists]
        self.n = [0, 0]
        self.t = 0.
        self.tmax = max(self.snaplists[0][-1].t, self.snaplists[1][-1].t)
        self.dt = self.tmax / 200. if dt is None else dt
        self.delay = 1
        self.symbol = symbol
        self.markersize = markersize
        self.pause_switch = True
        self.BackgroundColor = 'white'
        self.kwargs = kwargs

    @staticmethod
    def _padded_copy(snapshot, t):
        """Return a deep copy of *snapshot* whose time is set to *t*."""
        padded = deepcopy(snapshot)
        padded.t = t
        return padded

    def create_frame(self):
        """Create the figure with two 3D axes, a Play/Pause button and a time slider."""
        self.fig = plt.figure(figsize=(15, 10))
        print("fig created")
        self.ax = []
        for i in range(2):
            # 'facecolor' replaces the 'axisbg' keyword removed from matplotlib
            ax = self.fig.add_subplot(121 + i, projection='3d',
                                      adjustable='box',
                                      facecolor=self.BackgroundColor)
            try:
                ax.set_aspect('equal')
            except NotImplementedError:
                # some matplotlib versions refuse equal aspect on 3D axes; the
                # identical x/y/z half-ranges set below still give a cubic box
                pass
            plt.tight_layout()
            ax.set_xlabel("x (pc)")
            ax.set_ylabel("y (pc)")
            ax.set_zlabel("z (pc)")
            self.ax.append(ax)
        X, Y, Z = [], [], []
        for snapl in self.snaplists:
            X.append(snapl[0].x)
            Y.append(snapl[0].y)
            Z.append(snapl[0].z)
        # common cubic view box, sized from the first run's initial snapshot
        max_range = np.array([X[0].max() - X[0].min(),
                              Y[0].max() - Y[0].min(),
                              Z[0].max() - Z[0].min()]).max() / 3.0
        mean_x = X[0].mean()
        mean_y = Y[0].mean()
        mean_z = Z[0].mean()
        self.fig.subplots_adjust(bottom=0.06)
        self.line, self.canvas = [], []
        for (x, y, z, ax) in zip(X, Y, Z, self.ax):
            self.line.append(ax.plot(x, y, z, self.symbol, markersize=self.markersize,
                                     **self.kwargs)[0])
            self.canvas.append(ax.figure.canvas)
            ax.set_xlim(mean_x - max_range, mean_x + max_range)
            ax.set_ylim(mean_y - max_range, mean_y + max_range)
            ax.set_zlim(mean_z - max_range, mean_z + max_range)
        ax_pauseB = plt.axes([0.04, 0.02, 0.06, 0.025])
        self.pauseB = Button(ax_pauseB, 'Play')
        self.pauseB.on_clicked(self.Pause_button)
        slider_ax = plt.axes([0.18, 0.02, 0.73, 0.025])
        # the slider spans up to the end of the longer run, not just run 1
        self.slider_time = Slider(slider_ax,
                                  "Time",
                                  self.snaplists[0][0].t,
                                  self.tmax,
                                  valinit=self.snaplists[0][0].t,
                                  color='#AAAAAA')
        self.slider_time.on_changed(self.slider_time_update)

    def _current_indices(self):
        """Return, per run, the index of the last snapshot with ``t < self.t`` (0 if none)."""
        indices = []
        for times in self.times:
            # compare as an array: a plain list cannot be compared with a float
            earlier = np.nonzero(np.asarray(times) < self.t)[0]
            indices.append(int(earlier.max()) if earlier.size else 0)
        return indices

    def update_lines(self):
        """Show, in each panel, the snapshot current at time ``self.t`` and redraw."""
        for line, snaplist, n in zip(self.line, self.snaplists, self._current_indices()):
            line.set_data(snaplist[n].x, snaplist[n].y)
            line.set_3d_properties(snaplist[n].z)
        for canv in self.canvas:
            canv.draw()

    def timer_update(self, lines):
        """Timer callback: advance time by ``dt`` (wrapping at ``tmax``) unless paused."""
        if not self.pause_switch:
            self.t = self.t + self.dt
            if self.t > self.tmax:
                self.t = self.t - self.tmax
            self.update_lines()
            self.slider_time.set_val(self.t)

    def Pause_button(self, event):
        """Toggle between playing and paused."""
        self.pause_switch = not self.pause_switch

    def slider_time_update(self, val):
        """Slider callback: jump to time *val* and redraw."""
        self.t = val
        self.update_lines()

    def launch(self):
        """Build the figure and start the animation timer."""
        self.create_frame()
        self.timer = self.fig.canvas.new_timer(interval=self.delay)
        args = [self.line]
        # the timer calls timer_update every self.delay milliseconds
        self.timer.add_callback(self.timer_update, *args)
        self.timer.start()
Exemple #43
0
class UI:
    """Two side-by-side 3D panels (entry/exit) with a colorbar and a time slider.

    ``ss_plotdata`` must provide ``n``, ``title``, ``mapping`` (the mappable
    used for the colorbar), ``colorbar_label``, ``lag_units`` and
    ``update_data(entry_ax, exit_ax, offset)``, which redraws both panels.
    ``ppf`` is the number of time steps advanced per animation frame.
    """

    def __init__(self, ss_plotdata, min_limit, max_limit, coloring, ppf):
        self.ss_plotdata = ss_plotdata
        self.n = self.ss_plotdata.n
        self.ppf = ppf
        self.min_limit = min_limit
        self.max_limit = max_limit
        self.azimuth = -65
        self.elevation = 23

        self.fig = plt.figure()
        self.fig.set_size_inches(17, 9)
        layout = gridspec.GridSpec(2, 3,
                                   width_ratios=[1, 1, 0.1],
                                   height_ratios=[1, 0.05],
                                   left=0.05)
        self.fig.suptitle(ss_plotdata.title)

        self.entry_ax = self._make_3d_axis(layout[0])
        self.exit_ax = self._make_3d_axis(layout[1])

        # colorbar; month-based colorings get one labelled tick per month
        self.cb_ax = plt.subplot(layout[2])
        monthly = coloring in (Coloring.DISCRETE_MONTHS,
                               Coloring.DISCRETE_MONTHS_SPLIT_MARKERS,
                               Coloring.DISCRETE_MONTHS_POLYGONS)
        colorbar_kwargs = dict(cax=self.cb_ax, label=ss_plotdata.colorbar_label)
        if monthly:
            colorbar_kwargs['ticks'] = np.arange(1.5, 13.5, 1)
        cbar = self.fig.colorbar(ss_plotdata.mapping, **colorbar_kwargs)
        if monthly:
            cbar.ax.set_yticklabels(calendar.month_abbr[1:])

        # time slider spanning the whole bottom row
        self.slider_ax = plt.subplot(layout[1, :])
        slider_format = "%0.0f " + ss_plotdata.lag_units
        self.time_slider = Slider(self.slider_ax, "Time", 0, self.n,
                                  valinit=ss_plotdata.n,
                                  valstep=1,
                                  valfmt=slider_format)

        self.timestep = 0
        self.time_slider.on_changed(self.update_time)

    def _make_3d_axis(self, spec):
        """Create a 3D subplot at *spec* using the shared camera angles."""
        axis = plt.subplot(spec, projection='3d')
        axis.view_init(elev=self.elevation, azim=self.azimuth)
        return axis

    def draw(self, offset):
        """Redraw both panels for time *offset* and re-apply the axis limits."""
        self.ss_plotdata.update_data(self.entry_ax, self.exit_ax, offset)
        self.set_limit()

    def update_time(self, val):
        """Slider callback: draw the slider's (integer) time step."""
        self.timestep = int(self.time_slider.val)
        self.draw(self.timestep)
        self.fig.canvas.draw_idle()

    def set_limit(self):
        """Apply the configured min/max limits to both 3D axes."""
        set_all_limits(self.exit_ax, self.min_limit, self.max_limit)
        set_all_limits(self.entry_ax, self.min_limit, self.max_limit)

    def animate(self, k):
        """Animation step: advance ``ppf`` steps (modulo ``n``) and move the slider."""
        self.timestep += self.ppf
        self.timestep %= self.n
        percent = 100 * self.timestep / self.n
        print('\r{0:.2f}%'.format(percent), end='')
        self.time_slider.set_val(self.timestep)
        return []

    def _build_animation(self, repeat):
        """Create the FuncAnimation driving ``animate`` over all time steps."""
        frames = self.ppf * np.arange(0, (self.n + 1) // self.ppf)
        return animation.FuncAnimation(self.fig, self.animate, frames,
                                       interval=20, repeat=repeat, blit=True)

    def render_animation_to_file(self, outfile):
        """Render the full animation to *outfile* with the ffmpeg writer."""
        print("Rendering....")
        self.ani = self._build_animation(repeat=False)
        self.ani.save(outfile, writer="ffmpeg")

    def render_image_to_file(self, outfile):
        """Draw the final time step and save the figure to *outfile*."""
        self.draw(self.n)
        plt.draw()
        self.fig.savefig(outfile)

    def show_animation(self):
        """Play the animation in a looping interactive window."""
        self.ani = self._build_animation(repeat=True)
        plt.show()

    def render_image(self, offset):
        """Placeholder; currently does nothing."""
        pass
Exemple #44
0
class PlotFrame(wx.Frame):
    """
        PlotFrame is a custom wxPython frame to hold the panel with a
        Figure and WxAgg backend canvas for matplotlib plots or other
        figures.  In this frame:

        self is an instance of a wxFrame;
        axes is an instance of MPL Axes;
        fig is an instance of MPL Figure;
        panel is an instance of wxPanel, used for the main panel, to hold
        canvas, an instance of MPL FigureCanvasWxAgg.
    """

    # Main function to set everything up when the frame is created
    def __init__(self, title, pos, size):
        """
           This will be executed when an instance of PlotFrame is created.
           It is the place to define any globals as "self.<name>".
        """
        wx.Frame.__init__(self, None, wx.ID_ANY, title, pos, size)

        # An optional data file name may be given as the first CLI argument.
        self.filename = sys.argv[1] if len(sys.argv) > 1 else ""

        # set some Boolean flags
        self.STOP = False
        self.data_loaded = False
        self.reverse_play = False

        # frames advanced per displayed frame: 1 (Normal) or 10 (Fast)
        self.step = 1

        #    Make the main Matplotlib panel for plots
        self.create_main_panel()  # creates canvas and contents

        # Then add wxPython widgets below the MPL canvas
        # Layout with box sizers

        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.EXPAND)
        self.sizer.AddSpacer(10)
        self.sizer.Add(self.toolbar, 0, wx.EXPAND)
        self.sizer.AddSpacer(10)

        #    Make the control panel with a row of buttons
        self.create_button_bar()
        self.sizer.Add(self.button_bar_sizer, 0, flag=wx.ALIGN_CENTER | wx.TOP)

        #    Make a Status Bar
        self.statusbar = self.CreateStatusBar()
        self.sizer.Add(self.statusbar, 0, wx.EXPAND)

        self.SetStatusText("Frame created ...")

        # -------------------------------------------------------
        #              set up the Menu Bar
        # -------------------------------------------------------
        menuBar = wx.MenuBar()

        menuFile = wx.Menu()  # File menu
        menuFile.Append(1, "&Open", "Filename(s) or wildcard list to plot")
        menuFile.Append(3, "Save", "Save plot as a PNG image")
        menuFile.AppendSeparator()
        menuFile.Append(10, "E&xit")
        menuBar.Append(menuFile, "&File")

        menuHelp = wx.Menu()  # Help menu
        menuHelp.Append(11, "&About Netview")
        menuHelp.Append(12, "&Usage and Help")
        menuHelp.Append(13, "Program &Info")

        menuBar.Append(menuHelp, "&Help")
        self.SetMenuBar(menuBar)

        self.panel.SetSizer(self.sizer)
        self.sizer.Fit(self)

        # -------------------------------------------------------
        #      Bind the menu items to functions
        # -------------------------------------------------------

        self.Bind(wx.EVT_MENU, self.OnOpen, id=1)
        self.Bind(wx.EVT_MENU, self.OnSave, id=3)
        self.Bind(wx.EVT_MENU, self.OnQuit, id=10)
        self.Bind(wx.EVT_MENU, self.OnAbout, id=11)
        self.Bind(wx.EVT_MENU, self.OnUsage, id=12)
        self.Bind(wx.EVT_MENU, self.OnInfo, id=13)

        # Data is read and plotted on request (New Data / Play buttons),
        # not here.

    # ---------- end of __init__ ----------------------------

    # -------------------------------------------------------
    #   Function to make the main Matplotlib panel for plots
    # -------------------------------------------------------

    def create_main_panel(self):
        """ create_main_panel creates the main mpl panel with instances of:
             * mpl Canvas
             * mpl Figure
             * mpl Axes with subplot
             * mpl Widget class Sliders and Button
             * mpl navigation toolbar
           self.axes is the instance of MPL Axes, and is where it all happens
        """

        self.panel = wx.Panel(self)

        # Create the mpl Figure and FigCanvas objects.
        # 3.5 x 5 inches, 100 dots-per-inch
        #
        self.dpi = 100
        self.fig = Figure((3.5, 5.0), dpi=self.dpi)
        self.canvas = FigCanvas(self.panel, wx.ID_ANY, self.fig)

        # Since we have only one plot, we could use add_axes
        # instead of add_subplot, but then the subplot
        # configuration tool in the navigation toolbar wouldn't work.

        self.axes = self.fig.add_subplot(111)
        # (111) == (1,1,1) --> row 1, col 1, Figure 1)

        # Now create some sliders below the plot after making room
        self.fig.subplots_adjust(left=0.1, bottom=0.20)

        self.axtmin = self.fig.add_axes([0.2, 0.10, 0.5, 0.03])
        self.axtmax = self.fig.add_axes([0.2, 0.05, 0.5, 0.03])

        self.stmin = Slider(self.axtmin, 't_min:', 0.0, 1.0, valinit=0.0)
        self.stmax = Slider(self.axtmax, 't_max:', 0.0, 1.0, valinit=1.0)
        self.stmin.on_changed(self.update_trange)
        self.stmax.on_changed(self.update_trange)

        self.axbutton = self.fig.add_axes([0.8, 0.07, 0.1, 0.07])
        self.reset_button = Button(self.axbutton, 'Reset')
        self.reset_button.color = 'skyblue'
        self.reset_button.hovercolor = 'lightblue'
        self.reset_button.on_clicked(self.reset_trange)

        # Create the navigation toolbar, tied to the canvas

        self.toolbar = NavigationToolbar(self.canvas)

    def update_trange(self, event):
        """Slider callback: copy both slider values into self.t_min/t_max."""
        self.t_min = self.stmin.val
        self.t_max = self.stmax.val

    def reset_trange(self, event):
        """'Reset' button callback: return both sliders to their valinit."""
        self.stmin.reset()
        self.stmax.reset()

    def create_button_bar(self):
        """
        create_button_bar makes a control panel bar with buttons and
        toggles for

        New Data - Play - STOP - Single Step - Forward/Back - Normal/Fast

        It does not create a Panel container, but simply creates Button
        objects with bindings, and adds them to a horizontal BoxSizer
        self.button_bar_sizer.  This is added to the PlotFrame vertical
        BoxSizer, after the MPL canvas, during initialization of the frame.

        """
        rewind_button = wx.Button(self.panel, -1, "New Data")
        self.Bind(wx.EVT_BUTTON, self.OnRewind, rewind_button)

        replot_button = wx.Button(self.panel, -1, "Play")
        self.Bind(wx.EVT_BUTTON, self.OnReplot, replot_button)

        sstep_button = wx.Button(self.panel, -1, "Single Step")
        self.Bind(wx.EVT_BUTTON, self.OnSstep, sstep_button)

        stop_button = wx.Button(self.panel, -1, "STOP")
        self.Bind(wx.EVT_BUTTON, self.OnStop, stop_button)

        # The toggle buttons need to be globally accessible

        self.forward_toggle = wx.ToggleButton(self.panel, -1, "Forward")
        self.forward_toggle.SetValue(True)
        self.forward_toggle.SetLabel("Forward")
        self.Bind(wx.EVT_TOGGLEBUTTON, self.OnForward, self.forward_toggle)

        self.fast_toggle = wx.ToggleButton(self.panel, -1, " Normal ")
        self.fast_toggle.SetValue(True)
        self.fast_toggle.SetLabel(" Normal ")
        self.Bind(wx.EVT_TOGGLEBUTTON, self.OnFast, self.fast_toggle)

        # Set button colors to some simple colors that are likely
        # to be independent on X11 color definitions.  Some nice
        # bit maps (from a media player skin?) should be used
        # or the buttons and toggle state colors in OnFast() below

        rewind_button.SetBackgroundColour('skyblue')
        replot_button.SetBackgroundColour('skyblue')
        sstep_button.SetBackgroundColour('skyblue')
        stop_button.SetBackgroundColour('skyblue')
        self.forward_toggle.SetForegroundColour('black')
        self.forward_toggle.SetBackgroundColour('yellow')
        self.fast_toggle.SetForegroundColour('black')
        self.fast_toggle.SetBackgroundColour('yellow')
        self.button_bar_sizer = wx.BoxSizer(wx.HORIZONTAL)
        flags = wx.ALIGN_CENTER | wx.ALL
        self.button_bar_sizer.Add(rewind_button, 0, border=3, flag=flags)
        self.button_bar_sizer.Add(replot_button, 0, border=3, flag=flags)
        self.button_bar_sizer.Add(sstep_button, 0, border=3, flag=flags)
        self.button_bar_sizer.Add(stop_button, 0, border=3, flag=flags)
        self.button_bar_sizer.Add(self.forward_toggle, 0, border=3, flag=flags)
        self.button_bar_sizer.Add(self.fast_toggle, 0, border=3, flag=flags)

    # -------------------------------------------------------
    #   Functions to generate or read (x,y) data and plot it
    # -------------------------------------------------------

    def _read_first_line(self):
        """Return the first line of self.filename as text.

        The file is first read as bzip2; if that raises IOError (e.g. it is
        not bzip2 data) it is re-read as plain text.  Both files are closed.
        """
        try:
            with bz2.BZ2File(self.filename) as fp:
                line = fp.readline()
        except IOError:
            # then assume plain text
            with open(self.filename) as fp:
                line = fp.readline()
        if isinstance(line, bytes):
            # BZ2File returns bytes on Python 3; decode so that the '#'
            # header test below compares text with text.
            line = line.decode("utf-8", "replace")
        return line

    def get_data_params(self):
        """
        Set Ntimes, t_min, dt, NX and NY for the data file self.filename.

        The values are taken from a header line starting with '#' at the top
        of the file (plain text or bzip2-compressed).  Without a header, a
        ParamEntryDialog asks the user for them.  The t_min/t_max sliders
        are then rescaled to span the data.

        Returns True on success, or False if the user cancelled File/Open
        or the parameter dialog (no parameters are changed in that case).
        """
        # Here check to see if a filename should be entered from File/Open
        if not self.filename:
            # fake a button press of File/Open
            self.OnOpen(None)
            if not self.filename:
                self.SetStatusText("No data file was given")
                return False

        line = self._read_first_line()

        # check if first line is a header line starting with '#'
        header = line.split()
        if header and header[0].startswith("#"):
            self.Ntimes = int(header[1])
            self.t_min = float(header[2])
            self.dt = float(header[3])
            self.NX = int(header[4])
            self.NY = int(header[5])
        else:
            pdentry = self.ParamEntryDialog()
            try:
                if pdentry.ShowModal() != wx.ID_OK:
                    return False
                self.Ntimes = int(pdentry.Ntimes_dialog.entry.GetValue())
                self.t_min = float(pdentry.tmin_dialog.entry.GetValue())
                self.dt = float(pdentry.dt_dialog.entry.GetValue())
                self.NX = int(pdentry.NX_dialog.entry.GetValue())
                self.NY = int(pdentry.NY_dialog.entry.GetValue())
                print('Ntimes = %s  t_min = %s' % (self.Ntimes, self.t_min))
                print('NX = %s  NY = %s' % (self.NX, self.NY))
            finally:
                pdentry.Destroy()

        self.t_max = (self.Ntimes - 1) * self.dt
        # reset slider max and min
        self.stmin.valmax = self.t_max
        self.stmin.valinit = self.t_min
        self.stmax.valmax = self.t_max
        self.stmax.valinit = self.t_max
        # A Slider sets its axes' x-range only when it is constructed;
        # assigning valmax afterwards does not rescale the drawn bar.
        if self.t_max > self.stmin.valmin:
            self.axtmin.set_xlim(self.stmin.valmin, self.t_max)
        if self.t_max > self.stmax.valmin:
            self.axtmax.set_xlim(self.stmax.valmin, self.t_max)
        self.stmax.set_val(self.t_max)
        self.stmin.reset()
        self.stmax.reset()
        return True

    def init_plot(self):
        '''
        init_plot creates the initial plot display. A normal MPL plot
        would be created here with a command "self.axes.plot(x, y)" in
        order to create a plot of points in the x and y arrays on the
        Axes subplot.  Here, we create an AxesImage instance with
        imshow(), instead.  The initial image is a blank one of the
        proper dimensions, filled with zeroes.

        '''
        self.t_max = (self.Ntimes - 1) * self.dt
        self.axes.set_title("View of " + self.filename)
        # Note that NumPy array (row, col) = image (y, x)
        data0 = np.zeros((self.NY, self.NX))

        # Define a 'cold' to 'hot' color scale based in GENESIS 2 'hot'
        hotcolors = [
            '#000032', '#00003c', '#000046', '#000050', '#00005a', '#000064',
            '#00006e', '#000078', '#000082', '#00008c', '#000096', '#0000a0',
            '#0000aa', '#0000b4', '#0000be', '#0000c8', '#0000d2', '#0000dc',
            '#0000e6', '#0000f0', '#0000fa', '#0000ff', '#000af6', '#0014ec',
            '#001ee2', '#0028d8', '#0032ce', '#003cc4', '#0046ba', '#0050b0',
            '#005aa6', '#00649c', '#006e92', '#007888', '#00827e', '#008c74',
            '#00966a', '#00a060', '#00aa56', '#00b44c', '#00be42', '#00c838',
            '#00d22e', '#00dc24', '#00e61a', '#00f010', '#00fa06', '#00ff00',
            '#0af600', '#14ec00', '#1ee200', '#28d800', '#32ce00', '#3cc400',
            '#46ba00', '#50b000', '#5aa600', '#649c00', '#6e9200', '#788800',
            '#827e00', '#8c7400', '#966a00', '#a06000', '#aa5600', '#b44c00',
            '#be4200', '#c83800', '#d22e00', '#dc2400', '#e61a00', '#f01000',
            '#fa0600', '#ff0000', '#ff0a00', '#ff1400', '#ff1e00', '#ff2800',
            '#ff3200', '#ff3c00', '#ff4600', '#ff5000', '#ff5a00', '#ff6400',
            '#ff6e00', '#ff7800', '#ff8200', '#ff8c00', '#ff9600', '#ffa000',
            '#ffaa00', '#ffb400', '#ffbe00', '#ffc800', '#ffd200', '#ffdc00',
            '#ffe600', '#fff000', '#fffa00', '#ffff00', '#ffff0a', '#ffff14',
            '#ffff1e', '#ffff28', '#ffff32', '#ffff3c', '#ffff46', '#ffff50',
            '#ffff5a', '#ffff64', '#ffff6e', '#ffff78', '#ffff82', '#ffff8c',
            '#ffff96', '#ffffa0', '#ffffaa', '#ffffb4', '#ffffbe', '#ffffc8',
            '#ffffd2', '#ffffdc', '#ffffe6', '#fffff0'
        ]

        cmap = matplotlib.colors.ListedColormap(hotcolors)

        # imshow(cmap=...) already attaches the colormap to the image.
        # cm.jet, cm.gnuplot and cm.afmhot are other 'cold' to 'hot'
        # mappings, but are unlike G2 'hot'.
        self.im = self.axes.imshow(data0, cmap=cmap, origin='lower')

        # refresh the canvas
        self.canvas.draw()

    def get_xyt_data(self):
        """
        Load self.filename into self.ldata, scaled to the range 0-1.

        self.ldata has shape (Ntimes, NY, NX); frame k is self.ldata[k].
        Sets self.data_loaded to True when done.
        """
        # Create scaled (0-1) luminance(x,y) array from ascii G-2 disk_out file
        # The '#' header line is skipped by loadtxt as a comment.
        # NOTE: np.loadtxt only decompresses bzip2 data when the file name
        # ends in '.bz2'; other compressed files are not decompressed here.
        self.SetStatusText('Data loading - please wait ....')
        rawdata = np.loadtxt(self.filename)
        # Note the difference between NumPy [row, col] order and network
        # x-y grid (x, y) = (col, row). We want a NumPy NY x NX, not
        # NX x NY, array to be used by the AxesImage object.
        # NOTE(review): np.resize silently repeats or truncates data when
        # the file does not hold exactly Ntimes*NY*NX values.
        xydata = np.resize(rawdata, (self.Ntimes, self.NY, self.NX))
        # imshow expects the data to be scaled to range 0-1.
        Vmin = xydata.min()
        Vmax = xydata.max()
        span = Vmax - Vmin
        if span > 0:
            self.ldata = (xydata - Vmin) / span
        else:
            # Constant data: (x - Vmin) / 0 would give all-NaN frames.
            self.ldata = np.zeros(xydata.shape)
        self.data_loaded = True
        self.SetStatusText('Data has been loaded - click Play')

    def _require_data(self):
        """Return True if data is loaded; otherwise show a warning, return False."""
        if self.data_loaded:
            return True
        msg = """
            Data for plotting has not been loaded!
            Please enter the file to plot with File/Open, unless
            it was already specified, and then click on 'New Data'
            to load the data to play back, before clicking 'Play'.
            """
        wx.MessageBox(msg, "Plot Warning", wx.OK | wx.ICON_ERROR, self)
        return False

    def plot_data(self):
        ''' plot_data() shows successive frames of the data that was loaded
            into the ldata array.  Creating a new self.im AxesImage instance
            for each frame is extremely slow, so the set_data method of
            AxesImage is used to load new data into the existing self.im for
            each frame.  Normally 'self.canvas.draw()' would be used to
            display a frame, but redrawing the entire canvas, redraws the
            axes, labels, sliders, buttons, or anything else on the canvas.
            This uses a method taken from an example in Ch 7, p. 192
            Matplotlib for Python developers, with draw_artist() and blit()
            redraw only the part that was changed.

        '''
        if not self._require_data():
            return

        # set color limits
        self.im.set_clim(0.0, 1.0)
        self.im.set_interpolation('nearest')
        # 'None' is slightly faster, but not implemented for MPL ver < 1.1

        # do an initial draw, then save the empty figure axes
        self.canvas.draw()

        # A background save/restore (canvas.copy_from_bbox) is only needed
        # if axes legends etc. change; draw_artist() and blit() are much
        # faster than canvas.draw() and are sufficient here.

        print('system time (seconds) = %s' % time.time())

        # round frame_min down and frame_max up for the time window
        frame_min = int(self.t_min / self.dt)
        frame_max = min(int(self.t_max / self.dt) + 1, self.Ntimes)
        frame_step = self.step

        # Displaying simulation time to the status bar is much faster
        # than updating a slider progress bar, but location isn't optimum.
        # The STOP check only sees clicks processed by the event loop,
        # which this loop does not run while it is playing.

        # for reverse play interchange frame_min, frame_max and step backwards
        if self.reverse_play:
            frame_min = min(int(self.t_max / self.dt) + 1, self.Ntimes) - 1
            frame_max = int(self.t_min / self.dt) - 1
            frame_step = -self.step
        for frame_num in range(frame_min, frame_max, frame_step):
            self.SetStatusText('time: ' + str(frame_num * self.dt))
            if self.STOP:
                self.t_min = frame_num * self.dt
                # set t_min slider ?
                self.STOP = False
                break
            self.im.set_data(self.ldata[frame_num])
            self.axes.draw_artist(self.im)
            self.canvas.blit(self.axes.bbox)

        print('system time (seconds) = %s' % time.time())

    #  ------------------------------------------------------------------
    #   Define the classes and functions for getting parameter values
    #  --------------------------------------------------------------

    class ParamEntryDialog(wx.Dialog):
        """Modal dialog asking for Ntimes, start time, dt, NX and NY."""

        def __init__(self):
            wx.Dialog.__init__(self, None, wx.ID_ANY)
            self.SetSize((250, 200))
            self.SetTitle('Enter Data File Parameters')
            vbox = wx.BoxSizer(wx.VERTICAL)
            self.Ntimes_dialog = XDialog(self)
            self.Ntimes_dialog.entry_label.SetLabel('Number of entries')
            self.Ntimes_dialog.entry.ChangeValue(str(2501))
            self.tmin_dialog = XDialog(self)
            self.tmin_dialog.entry_label.SetLabel('Start time (sec)')
            self.tmin_dialog.entry.ChangeValue(str(0.0))

            self.dt_dialog = XDialog(self)
            self.dt_dialog.entry_label.SetLabel('Output time step (sec)')
            self.dt_dialog.entry.ChangeValue(str(0.0002))

            self.NX_dialog = XDialog(self)
            self.NX_dialog.entry_label.SetLabel('Number of cells on x-axis')
            self.NX_dialog.entry.ChangeValue(str(32))
            self.NY_dialog = XDialog(self)
            self.NY_dialog.entry_label.SetLabel('Number of cells on y-axis')
            self.NY_dialog.entry.ChangeValue(str(32))

            vbox.Add(self.Ntimes_dialog, 0, wx.EXPAND | wx.ALL, border=5)
            vbox.Add(self.tmin_dialog, 0, wx.EXPAND | wx.ALL, border=5)
            vbox.Add(self.dt_dialog, 0, wx.EXPAND | wx.ALL, border=5)
            vbox.Add(self.NX_dialog, 0, wx.EXPAND | wx.ALL, border=5)
            vbox.Add(self.NY_dialog, 0, wx.EXPAND | wx.ALL, border=5)

            okButton = wx.Button(self, wx.ID_OK, 'Ok')
            vbox.Add(okButton, flag=wx.ALIGN_CENTER, border=10)

            # SetSizerAndFit both installs the sizer and fits the dialog.
            self.SetSizerAndFit(vbox)

    #  ------------------------------------------------------------------
    #   Define the functions executed on menu choices
    #  ---------------------------------------------------------------

    def OnQuit(self, event):
        """File/Exit: close the frame."""
        self.Close()

    def OnSave(self, event):
        """File/Save: write the figure to a user-chosen PNG file."""
        file_choices = "PNG (*.png)|*.png"
        dlg = wx.FileDialog(self,
                            message="Save plot as...",
                            defaultDir=os.getcwd(),
                            defaultFile="plot.png",
                            wildcard=file_choices,
                            style=wx.FD_SAVE | wx.FD_OVERWRITE_PROMPT)
        try:
            if dlg.ShowModal() == wx.ID_OK:
                path = dlg.GetPath()
                self.canvas.print_figure(path, dpi=self.dpi)
                self.SetStatusText("Saved to %s" % path)
        finally:
            dlg.Destroy()

    def OnAbout(self, event):
        """Help/About: show a short description of Netview."""
        msg = """

                      G-3 Netview ver. 1.7

Netview is a stand-alone Python application for viewing
the output of GENESIS 2 and 3 network simulations.
It is intended to replace GENESIS 2 SLI scripts that use the
XODUS 'xview' widget.

The design and operation is based on the G3Plot application
for creating 2D plots of y(t) or y(x) from data files.
Unlike G3Plot, the image created with Netview is an animated
representation of a rectangular network with colored squares
used to indicate the value of some variable at that position
and time.  Typically, this would be the membrane potenial of
a cell soma, or a synaptic current in a dendrite segment.

Help/Usage gives HTML help for using Netview.
This is the main Help page.

Help/Program Info provides some information about the
objects and functions, and the wxPython and matplotlib
classes used here.

Dave Beeman, August 2012
	"""
        dlg = wx.MessageDialog(self, msg, "About G-3 Netview",
                               wx.OK | wx.ICON_QUESTION)
        dlg.ShowModal()
        dlg.Destroy()

    def OnOpen(self, event):
        """File/Open: ask for a data file name (the data is not read yet)."""
        dlg = wx.TextEntryDialog(self,
                                 "File with x,y data to plot",
                                 "File Open",
                                 self.filename,
                                 style=wx.OK | wx.CANCEL)
        if dlg.ShowModal() == wx.ID_OK:
            self.filename = dlg.GetValue()
            # A new filename has been entered, but the data has not been read
            self.data_loaded = False
        dlg.Destroy()

    #  This starts with the long string of HTML to display
    class UsageFrame(wx.Frame):
        text = """
<HTML>
<HEAD></HEAD>
<BODY BGCOLOR="#D6E7F7">

<CENTER><H1>Using G-3 Netview</H1></CENTER>

<H2>Introduction and Quick Start</H2>

<p>Netview is a stand-alone Python application for viewing the output of
GENESIS 2 and 3 network simulations.  It is intended to replace GENESIS 2
SLI scripts that use the XODUS 'xview' widget.</p>

<p>The design and operation is based on the G3Plot application for creating 2D
plots of y(t) or y(x) from data files.  As with G3Plot, the main class
PlotFrame uses a basic wxPython frame to embed a matplotlib figure for
plotting.  It defines some basic menu items and a control panel of buttons
and toggles, each with bindings to a function to execute on a mouse click.</p>

<p>Unlike G3Plot, the image created with Netview is an animated
representation of a rectangular network with colored squares
used to indicate the value of some variable at that position
and time.  Typically, this would be the membrane potenial of
a cell soma, or a synaptic current in a dendrite segment.</p>

<h2>Usage</h2>

<p>The Menu Bar has <em>File/Open</em>, <em>File/Save</em>, and
<em>File/Exit</em> choices.  The Help Menu choices <em>About</em> and
<em>Usage</em> give further information.  The <em>Program Info</em>
selection shows code documentation that is contained in some of the main
function <em>docstrings</em>.</p>

<p>After starting the <em>netview</em> program, enter a data file name
in the dialog for File/Open, unless the filename was given as a
command line argument.  Then click on <strong>New Data</strong> to load the new
data and initialize the plot.  When the plot is cleared to black,
press <strong>Play</strong>.</p>

<p>The file types recognized are plain text or text files compressed with
bzip2.  The expected data format is one line for each output time step,
with each line having the membrane potential value of each cell in the net.
No time value should be given on the line.  In order to properly display
the data, netview needs some additional information about the network and
the data.  This can optionally be contained in a header line that precedes
the data.  If a header is not detected, a dialog will appear asking for the
needed parameters.</p>

<p>It is assumed that the cells are arranged on a NX x NY grid, numbered
from 0 (bottom left corner) to NX*NY - 1 (upper right corner).
In order to provide this information to netview, the data file should
begin with a header line of the form:</p>

<pre>
    #optional_RUNID_string Ntimes start_time dt NX NY SEP_X SEP_Y x0 y0 z0
</pre>

<p>The line must start with &quot;#&quot; and can optionally be followed immediately by any
string.  Typically this is some identification string generated by the
simulation run.  The following parameters, separated by blanks or any
whitespace, are:</p>

<ul>
<li>Ntimes - the number of lines in the file, exclusive of the header</li>
<li>start_time - the simulation time for the first data line (default 0.0)</li>
<li>dt - the time step used for output</li>
<li>NX, NY - the integer dimensions of the network</li>
<li>SEP_X, SEP_Y - the x,y distances between cells (optional)</li>
<li>x0, y0, z0 - the location of the compartment (data source) relative to the
cell origin</li>
</ul>

<p>The RUNID string and the last five parameters are not read or used
by netview.  These are available for other data analysis tools that
need a RUNID and the location of each source.</p>

<p>The slider bars can be used to set a time window for display, and the
<strong>Reset</strong> button can set t_min and t_max back to the defaults.
Use the <strong>Forward/Back</strong> toggle to reverse direction of
<strong>Play</strong>, and the <strong>Normal/Fast</strong> toggle to show
every tenth frame.</p> <p>The <strong>Single Step</strong> button can be
used to advance a single step at a time (or 10, if in 'Fast' mode).</p>

<p>The <strong>STOP</strong> button is currently not implemented</p>
<p>To plot different data, enter a new filename with <strong>File/Open</strong> and
repeat with <strong>New Data</strong> and <strong>Play</strong>.</p>

<HR>
</BODY>
</HTML>
        """

        def __init__(self, parent):
            wx.Frame.__init__(self,
                              parent,
                              -1,
                              "Usage and Help",
                              size=(640, 600),
                              pos=(400, 100))
            html = wx.html.HtmlWindow(self)
            html.SetPage(self.text)
            panel = wx.Panel(self, -1)
            button = wx.Button(panel, wx.ID_OK, "Close")
            self.Bind(wx.EVT_BUTTON, self.OnCloseMe, button)
            sizer = wx.BoxSizer(wx.VERTICAL)
            sizer.Add(html, 1, wx.EXPAND | wx.ALL, 5)
            sizer.Add(panel, 0, wx.ALIGN_CENTER | wx.ALL, 5)
            self.SetSizer(sizer)
            self.Layout()

        def OnCloseMe(self, event):
            self.Close(True)

        # ----------- end of class UsageFrame ---------------

    def OnUsage(self, event):
        """Help/Usage: open the HTML usage window."""
        usagewin = self.UsageFrame(self)
        usagewin.Show(True)

    def OnInfo(self, event):
        """Help/Program Info: show the main docstrings in a scrolled dialog."""
        msg = "Program information for PlotFrame obtained from docstrings:"
        # __doc__ is None when Python runs with -OO; substitute "".
        docs = [doc or "" for doc in (self.__doc__,
                                      self.create_main_panel.__doc__,
                                      self.create_button_bar.__doc__,
                                      self.init_plot.__doc__,
                                      self.plot_data.__doc__)]
        msg += "\n" + docs[0] + "\n" + "".join(docs[1:])
        dlg = wx.lib.dialogs.ScrolledMessageDialog(self, msg,
                                                   "PlotFrame Documentation")
        dlg.ShowModal()
        dlg.Destroy()

    #  ---------------------------------------------------------------
    #   Define the functions executed on control button click
    #  ---------------------------------------------------------------

    def OnRewind(self, event):
        """'New Data': read the file parameters, clear the plot, load data."""
        if not self.get_data_params():
            return  # user cancelled; keep whatever is currently loaded
        self.init_plot()
        self.get_xyt_data()

    def OnReplot(self, event):
        """'Play': play back the selected time window, then redraw."""
        self.plot_data()
        self.canvas.draw()

    def OnSstep(self, event):
        """'Single Step': advance t_max by one dt and show that frame."""
        if not self._require_data():
            return

        self.t_max = min(self.t_max + self.dt, (self.Ntimes - 1) * self.dt)
        self.stmax.set_val(self.t_max)
        frame_num = int(self.t_max / self.dt)
        self.SetStatusText('time: ' + str(frame_num * self.dt))
        self.im.set_data(self.ldata[frame_num])
        self.axes.draw_artist(self.im)
        self.canvas.blit(self.axes.bbox)

    def OnStop(self, event):
        """'STOP': ask plot_data() to stop at its next frame."""
        # Must be a real bool: plot_data() tests it for truth and clears it.
        self.STOP = True

    def OnForward(self, event):
        """Forward/Back toggle: choose the playback direction."""
        state = self.forward_toggle.GetValue()
        if state:
            self.reverse_play = False
            self.forward_toggle.SetLabel("Forward ")
            self.forward_toggle.SetForegroundColour('black')
            self.forward_toggle.SetBackgroundColour('yellow')
        else:
            self.reverse_play = True
            self.forward_toggle.SetLabel("  Back  ")
            self.forward_toggle.SetForegroundColour('red')
            self.forward_toggle.SetBackgroundColour('green')

    def OnFast(self, event):
        """Normal/Fast toggle: play every frame (1) or every tenth (10)."""
        state = self.fast_toggle.GetValue()
        if state:
            self.fast_toggle.SetLabel(" Normal ")
            self.fast_toggle.SetForegroundColour('black')
            self.fast_toggle.SetBackgroundColour('yellow')
            self.step = 1
        else:
            self.fast_toggle.SetLabel("  Fast  ")
            self.fast_toggle.SetForegroundColour('red')
            self.fast_toggle.SetBackgroundColour('green')
            self.step = 10
class ParametricEQSelector:
    """Interactive matplotlib tool for designing a multi-band parametric EQ.

    The figure plots the response of a ``yodel.filter.ParametricEQ`` (either
    amplitude or phase, chosen with radio buttons).  A second radio group
    picks the band being edited, and three sliders set that band's centre
    frequency, Q factor and gain in dB.
    """

    def __init__(self):
        self._fs = 48000
        self._nbands = 7
        self._params = [
            {'center': 0, 'resonance': 1.0 / math.sqrt(2.0), 'dbgain': 0}
            for _ in range(self._nbands)
        ]
        self._selected_band = 0
        self._blocksize = 512
        self._nfft = self._blocksize // 2
        # Separate buffers; each is replaced once the response is computed.
        self._impulse_response = [0] * self._blocksize
        self._freq_response_real = [0] * self._blocksize
        self._freq_response_imag = [0] * self._blocksize
        self._response = [0] * self._blocksize
        self._plot_db = True
        self._eq = yodel.filter.ParametricEQ(self._fs, self._nbands)

        self._create_plot()
        self._create_plot_controls()
        self.select_band(f'Band {self._selected_band + 1}')

    def _create_plot(self):
        """Create the figure and axes and draw the initial response curves."""
        fig, ax = plt.subplots()
        self._fig, self._ax = fig, ax
        ax.set_title('Parametric Equalizer Design')
        ax.grid()
        plt.subplots_adjust(bottom=0.3)

        self._update_filter_response()
        bin_width = self._fs / 2 / self._nfft
        self._x_axis = [k * bin_width for k in range(self._nfft)]
        self._y_axis = self._response[:self._nfft]

        zero_line = [0] * self._nfft
        self._l_center, = ax.plot(self._x_axis, zero_line, 'k')
        self._l_fr, = ax.plot(self._x_axis, self._y_axis, 'b')

        self._rescale_plot()

    def _create_plot_controls(self):
        """Create the radio buttons and sliders and connect their callbacks."""
        self._dbrax = plt.axes([0.12, 0.05, 0.13, 0.10])
        self._dbradio = RadioButtons(self._dbrax, ('Amplitude', 'Phase'))
        self._dbradio.on_clicked(self.set_plot_style)

        self._rax = plt.axes([0.27, 0.03, 0.15, 0.20])
        band_labels = tuple(f'Band {n}' for n in range(1, self._nbands + 1))
        self._radio = RadioButtons(self._rax, band_labels)
        self._radio.on_clicked(self.select_band)

        self._sfax = plt.axes([0.6, 0.19, 0.2, 0.03])
        self._sqax = plt.axes([0.6, 0.12, 0.2, 0.03])
        self._sdbax = plt.axes([0.6, 0.05, 0.2, 0.03])
        current = self._params[self._selected_band]
        self._fcslider = Slider(self._sfax, 'Cut-off frequency', 0, self._fs / 2,
                                valinit=current['center'])
        self._qslider = Slider(self._sqax, 'Q factor', 0.01, 10.0,
                               valinit=current['resonance'])
        self._dbslider = Slider(self._sdbax, 'dB gain', -20.0, 20.0,
                                valinit=current['dbgain'])

        self._fcslider.on_changed(self.set_center_frequency)
        self._qslider.on_changed(self.set_resonance)
        self._dbslider.on_changed(self.set_dbgain)

    def _rescale_plot(self):
        """Set the y-range for the current display mode and redraw."""
        bottom, top = (-30, 30) if self._plot_db else (-200, 200)
        self._ax.set_ylim(bottom, top)
        plt.draw()

    def _plot_frequency_response(self, redraw=True):
        """Recompute the filter response and update the response line."""
        self._update_filter_response()
        self._y_axis = self._response[:self._nfft]
        self._l_fr.set_ydata(self._y_axis)
        if redraw:
            plt.draw()

    def _plot_range_limits(self, redraw=True):
        """Reset the reference (zero) line."""
        self._l_center.set_ydata([0] * self._nfft)
        if redraw:
            plt.draw()

    def set_plot_style(self, style):
        """Radio callback: switch between 'Amplitude' and 'Phase' display.

        Any other label leaves the current mode unchanged but still redraws.
        """
        if style == 'Amplitude':
            self._plot_db = True
        elif style == 'Phase':
            self._plot_db = False
        self._plot_range_limits(False)
        self._plot_frequency_response(False)
        self._rescale_plot()

    def select_band(self, band):
        """Radio callback: make ``band`` (e.g. ``'Band 3'``) the edited band.

        The three sliders are moved to that band's stored parameters.
        """
        self._selected_band = int(band.split(' ')[1]) - 1
        params = self._params[self._selected_band]
        slider_keys = (
            (self._fcslider, 'center'),
            (self._qslider, 'resonance'),
            (self._dbslider, 'dbgain'),
        )
        for slider, key in slider_keys:
            slider.set_val(params[key])

    def _update_selected(self, key, val):
        """Store ``val`` under ``key`` for the selected band and re-plot."""
        self._params[self._selected_band][key] = val
        self._set_band(self._selected_band)
        self._plot_frequency_response()

    def set_center_frequency(self, val):
        """Slider callback: set the selected band's centre frequency."""
        self._update_selected('center', val)

    def set_resonance(self, val):
        """Slider callback: set the selected band's Q factor."""
        self._update_selected('resonance', val)

    def set_dbgain(self, val):
        """Slider callback: set the selected band's gain in dB."""
        self._update_selected('dbgain', val)

    def _set_band(self, band):
        """Push the stored parameters of ``band`` into the EQ filter."""
        p = self._params[band]
        self._eq.set_band(band, p['center'], p['resonance'], p['dbgain'])

    def _update_filter_response(self):
        """Recompute the impulse and frequency responses of the EQ."""
        self._impulse_response = impulse_response(self._eq, self._blocksize)
        real, imag = frequency_response(self._impulse_response)
        self._freq_response_real, self._freq_response_imag = real, imag
        to_response = amplitude_response if self._plot_db else phase_response
        self._response = to_response(real, imag)
class EnergyPlusModel(metaclass=ABCMeta):
    """Abstract base class for EnergyPlus environment models.

    Subclasses define the action/observation spaces, state formatting,
    reward computation and per-episode reading/plotting/dumping.  This base
    class tracks the current and previous action, parses EnergyPlus
    timestamps, and provides an interactive training-progress viewer driven
    by the ``monitor.csv`` file in ``log_dir``.
    """

    def __init__(self, model_file, log_dir=None, verbose=False):
        self.log_dir = log_dir
        self.model_basename = os.path.splitext(os.path.basename(model_file))[0]
        # setup_spaces() must define self.action_space (its .low/.high are
        # read on the next line).
        self.setup_spaces()
        # Start from the midpoint of the action space.
        self.action = 0.5 * (self.action_space.low + self.action_space.high)
        self.action_prev = self.action
        self.raw_state = None
        self.verbose = verbose
        self.timestamp_csv = None
        self.sl_episode = None

        # Progress data
        self.num_episodes = 0
        self.num_episodes_last = 0

        self.reward = None
        self.reward_mean = None

    def reset(self):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def _parse_datetime(self, dstr):
        """Parse an EnergyPlus timestamp into a ``datetime``.

        Accepts ``' MM/DD  HH:MM:SS'`` or ``'MM/DD  HH:MM:SS'`` (the leading
        space is optional); seconds are ignored.  EnergyPlus can report hour
        24, which ``datetime`` cannot represent, so ``24:MM`` becomes
        ``00:MM`` of the following day.  The year is fixed to 2013.
        """
        # Normalise to the leading-space form so the fixed column offsets
        # below line up.
        if dstr[0] != ' ':
            dstr = ' ' + dstr
        year = 2013  # for CHICAGO_IL_USA TMY2-94846
        month = int(dstr[1:3])
        day = int(dstr[4:6])
        hour = int(dstr[8:10])
        minute = int(dstr[11:13])
        if hour == 24:
            return datetime(year, month, day, 0, minute) + timedelta(days=1)
        return datetime(year, month, day, hour, minute)

    def _convert_datetime24(self, dates):
        """Convert a list of EnergyPlus timestamp strings to ``datetime`` objects."""
        return [self._parse_datetime(d) for d in dates]

    def generate_x_pos_x_labels(self, dates):
        """Return x-tick positions and ``MM/DD`` labels at day boundaries.

        Every timestamp is shifted back by one sampling interval (the gap
        between the first two entries).  Indices whose shifted time is exactly
        midnight become tick positions, labelled with the shifted date.

        :param dates: EnergyPlus timestamp strings; needs at least two entries
            (otherwise ``IndexError`` is raised).
        :return: ``(x_pos, x_labels)`` lists.
        """
        time_delta = self._parse_datetime(dates[1]) - self._parse_datetime(dates[0])
        x_pos = []
        x_labels = []
        for i, d in enumerate(dates):
            dt = self._parse_datetime(d) - time_delta
            if dt.hour == 0 and dt.minute == 0:
                x_pos.append(i)
                x_labels.append(dt.strftime('%m/%d'))
        return x_pos, x_labels

    def set_action(self, action):
        """Store a new action, clipped to the bounds of ``action_space``.

        The previous action is kept in ``action_prev``.  The action is clipped,
        not rescaled, so it must already be expressed in action-space units.
        NOTE(review): an earlier comment said baseline agents (TRPO/PPO1/PPO2)
        emit actions normalised to [-1, 1] that the environment must rescale;
        no rescaling happens here — confirm that callers handle it.

        :raises AssertionError: if ``action.shape`` differs from the shape of
            ``action_space.low``.
        """
        assert action.shape == self.action_space.low.shape, 'Invalid action {}'.format(
            action)
        self.action_prev = self.action
        self.action = np.clip(action, self.action_space.low,
                              self.action_space.high)

    @abstractmethod
    def setup_spaces(self):
        pass

    # Need to handle the case that raw_state is None
    @abstractmethod
    def set_raw_state(self, raw_state):
        pass

    def get_state(self):
        """Return ``format_state(raw_state)`` for the current raw state."""
        return self.format_state(self.raw_state)

    @abstractmethod
    def compute_reward(self):
        pass

    @abstractmethod
    def format_state(self, raw_state):
        pass

    # --------------------------------------------------
    # Plotting stuff follows
    # --------------------------------------------------
    def _setup_plot_figure(self):
        """Reset rcParams to this viewer's plotting style and create figure 1."""
        plt.rcdefaults()
        plt.rcParams['font.size'] = 6
        plt.rcParams['lines.linewidth'] = 1.0
        plt.rcParams['legend.loc'] = 'lower right'
        self.fig = plt.figure(1, figsize=(16, 10))

    def plot(self, log_dir='', csv_file='', **kwargs):
        """Show training progress for ``log_dir`` or plot one episode CSV.

        A non-empty ``log_dir`` opens the interactive progress viewer.
        Otherwise ``csv_file`` is read and plotted.  An invalid path prints a
        message and returns.  ``kwargs`` is accepted and ignored.
        """
        # '!=' rather than 'is not': identity comparison with a string
        # literal is implementation-dependent (SyntaxWarning since 3.8).
        if log_dir != '':
            if not os.path.isdir(log_dir):
                print('energyplus_model.plot: {} is not a directory'.format(
                    log_dir))
                return
            print('energyplus_plot.plot log={}'.format(log_dir))
            self.log_dir = log_dir
            self.show_progress()
        else:
            if not os.path.isfile(csv_file):
                print(
                    'energyplus_model.plot: {} is not a file'.format(csv_file))
                return
            print('energyplus_model.plot csv={}'.format(csv_file))
            self.read_episode(csv_file)
            self._setup_plot_figure()
            self.plot_episode(csv_file)
            plt.show()

    @staticmethod
    def _make_button(ax, label, color, callback):
        """Create a navigation ``Button`` on ``ax`` and connect ``callback``."""
        button = Button(ax, label, color=color, hovercolor='0.975')
        button.on_clicked(callback)
        return button

    def show_progress(self):
        """Open the interactive training-progress viewer for ``self.log_dir``.

        The viewer shows the per-episode reward curve, an episode slider with
        First/Prev/Next/Last buttons, and the latest episode.  A one-second
        timer re-reads ``monitor.csv`` and refreshes the curve.  Exits the
        process with status 1 if no progress data can be read.
        """
        self.monitor_file = self.log_dir + '/monitor.csv'

        # Read progress file
        if not self.read_monitor_file():
            print('Progress data is missing')
            sys.exit(1)

        # Initialize graph
        self._setup_plot_figure()

        # Show widgets
        axcolor = 'lightgoldenrodyellow'
        self.axprogress = self.fig.add_axes([0.15, 0.10, 0.70, 0.15],
                                            facecolor=axcolor)
        self.axslider = self.fig.add_axes([0.15, 0.04, 0.70, 0.02],
                                          facecolor=axcolor)
        axfirst = self.fig.add_axes([0.15, 0.01, 0.03, 0.02])
        axlast = self.fig.add_axes([0.82, 0.01, 0.03, 0.02])
        axprev = self.fig.add_axes([0.46, 0.01, 0.03, 0.02])
        axnext = self.fig.add_axes([0.51, 0.01, 0.03, 0.02])

        # The episode slider itself is (re)created in plot_progress().

        # Buttons are kept on self: matplotlib widgets need a live reference
        # to stay responsive.
        self.button_first = self._make_button(axfirst, 'First', axcolor,
                                              self.first_episode_num)
        self.button_last = self._make_button(axlast, 'Last', axcolor,
                                             self.last_episode_num)
        self.button_prev = self._make_button(axprev, 'Prev', axcolor,
                                             self.prev_episode_num)
        self.button_next = self._make_button(axnext, 'Next', axcolor,
                                             self.next_episode_num)

        # Timer: poll monitor.csv once per second.
        self.timer = self.fig.canvas.new_timer(interval=1000)
        self.timer.add_callback(self.check_update)
        self.timer.start()

        # Progress data
        self.axprogress.set_xmargin(0)
        self.axprogress.set_xlabel('Episodes')
        self.axprogress.set_ylabel('Reward')
        self.axprogress.grid(True)
        self.plot_progress()

        # Plot latest episode
        self.update_episode(self.num_episodes - 1)

        plt.show()

    def check_update(self):
        """Timer callback: redraw the progress plot when new episodes appear."""
        if self.read_monitor_file():
            self.plot_progress()

    def plot_progress(self):
        """Redraw the reward curve and rebuild the episode slider.

        If there is no slider yet, or it sits on the previously-latest episode
        (``num_episodes - 2``), the new slider follows the latest episode.
        Otherwise it keeps the current selection.
        """
        # Remove old reward lines.  Assigning ``Axes.lines = []`` fails on
        # matplotlib >= 3.5 (read-only ArtistList); Artist.remove() works on
        # all versions.
        for line in list(self.axprogress.lines):
            line.remove()
        self.axprogress.plot(self.reward, color='#1f77b4', label='Reward')
        self.axprogress.legend()

        # Redraw slider
        if self.sl_episode is None:
            cur_ep = self.num_episodes - 1
        else:
            cur_ep = int(round(self.sl_episode.val))
            if cur_ep == self.num_episodes - 2:
                cur_ep = self.num_episodes - 1
            # Otherwise the old slider's canvas callbacks stay connected after
            # its axes is cleared, and it keeps reacting to mouse events.
            self.sl_episode.disconnect_events()
        self.axslider.clear()
        self.sl_episode = Slider(self.axslider,
                                 'Episode (0..{})'.format(self.num_episodes -
                                                          1),
                                 0,
                                 self.num_episodes - 1,
                                 valinit=cur_ep,
                                 valfmt='%6.0f')
        self.sl_episode.on_changed(self.set_episode_num)

    def read_monitor_file(self):
        """Reload ``self.monitor_file`` if it changed since the last read.

        On the first call this blocks, polling once per second, until the file
        exists.  The per-episode reward is stored as ``r / l``, i.e. the
        episode reward divided by the episode length.  ``reward_mean`` holds
        the same values.  Episode directories are derived from ``log_dir``.

        :return: True if the file was re-read and now lists more episodes than
            before, False otherwise.
        :raises AssertionError: on an unexpected header or column set.
        """
        # For the very first call, wait until monitor.csv is created
        if self.timestamp_csv is None:
            while not os.path.isfile(self.monitor_file):
                time.sleep(1)
            # '-1' is a hack to prevent losing the first set of data
            self.timestamp_csv = os.stat(self.monitor_file).st_mtime - 1

        ts = os.stat(self.monitor_file).st_mtime
        if ts <= self.timestamp_csv:
            return False

        # Monitor file is updated.
        self.timestamp_csv = ts
        with open(self.monitor_file) as f:
            firstline = f.readline()
            assert firstline.startswith('#')
            metadata = json.loads(firstline[1:])
            assert metadata['env_id'] == "EnergyPlus-v0"
            assert set(metadata.keys()) == {
                'env_id', 't_start'
            }, "Incorrect keys in monitor metadata"
            df = pd.read_csv(f, index_col=None)
        assert set(df.keys()) == {'l', 't',
                                  'r'}, "Incorrect keys in monitor logline"

        self.reward = []
        self.reward_mean = []
        self.episode_dirs = []
        self.num_episodes = 0
        for rew, ep_len in zip(df['r'], df['l']):
            self.reward.append(rew / ep_len)
            self.reward_mean.append(rew / ep_len)
            self.episode_dirs.append(
                self.log_dir +
                '/output/episode-{:08d}'.format(self.num_episodes))
            self.num_episodes += 1
        if self.num_episodes > self.num_episodes_last:
            self.num_episodes_last = self.num_episodes
            return True
        return False

    def update_episode(self, ep):
        """Plot episode ``ep``."""
        self.plot_episode(ep)

    def set_episode_num(self, val):
        """Slider callback: plot the episode at the slider's rounded value."""
        ep = int(round(self.sl_episode.val))
        self.update_episode(ep)

    def first_episode_num(self, val):
        """'First' button callback: move the slider to episode 0."""
        self.sl_episode.set_val(0)

    def last_episode_num(self, val):
        """'Last' button callback: move the slider to the latest episode."""
        self.sl_episode.set_val(self.num_episodes - 1)

    def prev_episode_num(self, val):
        """'Prev' button callback: step the slider back one episode, if possible."""
        ep = int(round(self.sl_episode.val))
        if ep > 0:
            self.sl_episode.set_val(ep - 1)

    def next_episode_num(self, val):
        """'Next' button callback: step the slider forward one episode, if possible."""
        ep = int(round(self.sl_episode.val))
        if ep < self.num_episodes - 1:
            self.sl_episode.set_val(ep + 1)

    def show_statistics(self, title, series):
        """Print average/min/max/std of ``series`` on one line."""
        ave, vmin, vmax, std = self.get_statistics(series)
        print('{:25} ave={:5,.2f}, min={:5,.2f}, max={:5,.2f}, std={:5,.2f}'.
              format(title, ave, vmin, vmax, std))

    def get_statistics(self, series):
        """Return ``(average, min, max, std)`` of ``series``."""
        return np.average(series), np.min(series), np.max(series), np.std(
            series)

    def show_distrib(self, title, series):
        """Print a histogram of ``series`` in 0.1-degree bins for 17.0-27.9.

        Values are binned by ``floor(v * 10)`` into 1000 bins, with values
        outside [0, 100) clamped into the first/last bin.  Each printed row
        shows the share of one whole degree followed by the share of each of
        its ten 0.1 sub-bins.  An empty ``series`` raises ZeroDivisionError.
        """
        dist = [0] * 1000
        for v in series:
            idx = min(max(int(math.floor(v * 10)), 0), 999)
            dist[idx] += 1
        total = len(series)
        print(title)
        print(
            '    degree 0.0-0.9 0.0   0.1   0.2   0.3   0.4   0.5   0.6   0.7   0.8   0.9'
        )
        print(
            '    -------------------------------------------------------------------------'
        )
        for t in range(170, 280, 10):
            print('    {:4.1f}C {:5.1%}  '.format(
                t / 10.0,
                sum(dist[t:(t + 10)]) / total),
                  end='')
            for tt in range(t, t + 10):
                print(' {:5.1%}'.format(dist[tt] / total), end='')
            print('')

    def get_episode_list(self, log_dir='', csv_file=''):
        """Populate ``episode_dirs`` and ``num_episodes``.

        Exactly one of ``log_dir`` or ``csv_file`` must be given; otherwise a
        message is printed and ``SystemExit`` is raised.  With ``log_dir``,
        every ``output/episode-????????`` directory containing
        ``eplusout.csv`` or ``eplusout.csv.gz`` is listed (sorted).  With
        ``csv_file``, its containing directory is the single episode.  A
        ``log_dir`` that is not a directory prints a message and returns.
        """
        # Both given or both missing: '!='/'==' instead of identity checks.
        if (log_dir == '') == (csv_file == ''):
            print('Either one of log_dir or csv_file must be specified')
            sys.exit()
        if log_dir != '':
            if not os.path.isdir(log_dir):
                print('energyplus_model.dump: {} is not a directory'.format(
                    log_dir))
                return
            print('energyplus_plot.dump: log={}'.format(log_dir))

            # Make a list of all episodes.
            # Note: sometimes the csv file is missing in an episode directory;
            # gzipped csv files are accepted too.
            csv_list = glob(log_dir + '/output/episode-????????/eplusout.csv') \
                       + glob(log_dir + '/output/episode-????????/eplusout.csv.gz')
            self.episode_dirs = sorted(set(os.path.dirname(p) for p in csv_list))
            self.num_episodes = len(self.episode_dirs)
        else:  # csv_file != ''
            self.episode_dirs = [os.path.dirname(csv_file)]
            self.num_episodes = len(self.episode_dirs)

    # Model dependent methods
    @abstractmethod
    def read_episode(self, ep):
        pass

    @abstractmethod
    def plot_episode(self, ep):
        pass

    @abstractmethod
    def dump_timesteps(self, log_dir='', csv_file='', **kwargs):
        pass

    @abstractmethod
    def dump_episodes(self, log_dir='', csv_file='', **kwargs):
        pass
class Player(FuncAnimation):
    """
    Animation player with step-back, back, stop, forward and step-forward
    buttons plus a frame slider.

    The frame index cycles through ``range(dis_start, dis_stop)`` and wraps
    around at both ends.
    """
    def __init__(self,
                 fig,
                 func,
                 init_func=None,
                 fargs=None,
                 save_count=None,
                 button_color='yellow',
                 bg_color='red',
                 dis_start=0,
                 dis_stop=100,
                 pos=(0.125, 0.05),
                 **kwargs):
        """
        initialization
        :param fig: matplotlib figure object
        :param func: user-defined function which takes an integer (frame number) as an input
        :param init_func: user-defined initial function used by the FuncAnimation class
        :param fargs: arguments of func, used by FuncAnimation class
        :param save_count: save count arg used by FuncAnimation class
        :param button_color: string, color of the buttons of the player
        :param bg_color: string, hovercolor of the buttons and face color of the slider
        :param dis_start: int, start frame number (inclusive)
        :param dis_stop: int, stop frame number (exclusive)
        :param pos: length 2 tuple, position of the buttons
        :param kwargs: kwargs for FuncAnimation class
        :raises ValueError: if dis_stop is not greater than dis_start
        """
        if dis_stop <= dis_start:
            # The ind setter wraps modulo (dis_stop - dis_start); an empty or
            # negative range would divide by zero or wrap nonsensically.
            raise ValueError(
                'dis_stop ({}) must be greater than dis_start ({})'.format(
                    dis_stop, dis_start))

        # setting up the index
        self.start_ind = dis_start
        self.stop_ind = dis_stop
        self.dis_length = self.stop_ind - self.start_ind
        self.ind = self.start_ind

        self.runs = True
        self.forwards = True
        self.fig = fig
        self.fig.set_facecolor('k')
        self.button_color = button_color
        self.bg_color = bg_color

        self.func = func
        self.setup(pos)
        FuncAnimation.__init__(self,
                               self.fig,
                               self.func,
                               frames=self.play(),
                               init_func=init_func,
                               fargs=fargs,
                               save_count=save_count,
                               **kwargs)

    @property
    def ind(self):
        """Current frame index, always within [start_ind, stop_ind)."""
        return self._ind

    @ind.setter
    def ind(self, val):
        # Wrap into [start_ind, stop_ind) in both directions.
        self._ind = (val - self.start_ind) % self.dis_length + self.start_ind

    def play(self):
        """
        Frame generator for FuncAnimation: advance one frame per step while
        running, in the current direction, and keep the slider in sync.
        """
        while self.runs:
            step = 1 if self.forwards else -1
            self.ind = self.ind + step
            self._update()
            yield self.ind

    def start(self):
        """Resume the animation."""
        self.runs = True
        self._update()
        self.event_source.start()

    def stop(self, event=None):
        """Pause the animation (button callback)."""
        self.runs = False
        self._update()
        self.event_source.stop()

    def forward(self, event=None):
        """Play forwards (button callback)."""
        self.forwards = True
        self.start()

    def backward(self, event=None):
        """Play backwards (button callback)."""
        self.forwards = False
        self.start()

    def oneforward(self, event=None):
        """Step one frame forwards (button callback)."""
        self.forwards = True
        self.onestep()

    def onebackward(self, event=None):
        """Step one frame backwards (button callback)."""
        self.forwards = False
        self.onestep()

    def onestep(self):
        """Move one frame in the current direction and draw it."""
        if self.forwards:
            self.ind += 1
        else:
            self.ind -= 1

        self.func(self.ind)

        self._update()
        self.fig.canvas.draw_idle()

    def _update(self):
        """Move the slider to the current frame index."""
        self.slider.set_val(self.ind)

    def __set_slider(self, val):
        """Slider callback: set the frame index (does not redraw the frame)."""
        val = int(val)
        self.ind = val

    def setup(self, pos):
        """
        Setting up the buttons and the slider
        :param pos: length 2 tuple, position of the axes for buttons and slider
        :return:
        """
        playerax = self.fig.add_axes([pos[0], pos[1], 0.22, 0.04])
        divider = mpl_toolkits.axes_grid1.make_axes_locatable(playerax)
        bax = divider.append_axes("right", size="80%", pad=0.05)
        sax = divider.append_axes("right", size="80%", pad=0.05)
        fax = divider.append_axes("right", size="80%", pad=0.05)
        ofax = divider.append_axes("right", size="100%", pad=0.05)
        sliderax = self.fig.add_axes(
            (pos[0], pos[1] - 0.045, 0.5, 0.04),
            facecolor=self.bg_color)

        self.button_oneback = matplotlib.widgets.Button(
            playerax,
            color=self.button_color,
            hovercolor=self.bg_color,
            label='$\u29CF$')
        self.button_back = matplotlib.widgets.Button(
            bax,
            color=self.button_color,
            hovercolor=self.bg_color,
            label='$\u25C0$')
        self.button_stop = matplotlib.widgets.Button(
            sax,
            color=self.button_color,
            hovercolor=self.bg_color,
            label='$\u25A0$')
        self.button_forward = matplotlib.widgets.Button(
            fax,
            color=self.button_color,
            hovercolor=self.bg_color,
            label='$\u25B6$')
        self.button_oneforward = matplotlib.widgets.Button(
            ofax,
            color=self.button_color,
            hovercolor=self.bg_color,
            label='$\u29D0$')
        self.button_oneback.on_clicked(self.onebackward)
        # The slider spans exactly the playable range [start_ind, stop_ind - 1];
        # a lower bound of 0 would let users pick frames before start_ind that
        # the ind setter then wraps to unrelated frames.
        self.slider = Slider(sliderax,
                             label='',
                             valfmt='%0.0f',
                             valmin=self.start_ind,
                             valmax=self.stop_ind - 1,
                             valinit=self.ind,
                             color='black',
                             fc=self.button_color)
        self.slider.label.set_color(self.button_color)
        self.slider.valtext.set_position((0.5, 0.5))

        self.slider.set_val(self.ind)

        self.button_back.on_clicked(self.backward)
        self.button_stop.on_clicked(self.stop)
        self.button_forward.on_clicked(self.forward)
        self.button_oneforward.on_clicked(self.oneforward)
        self.slider.on_changed(self.__set_slider)
Exemple #48
0
class CubeDisplayBase(ImageDisplay):
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def __init__(self, ax, data, coords=None, **kwargs):
        '''
        Image display for 3D data. Implements a frame slider and frame
        scrolling with the mouse wheel.

        Subclasses must implement set_frame, get_frame methods.

        Parameters
        ----------
        ax      :       Axes object
            Axes on which to display
        data    :       array-like
            initial display data
        coords  :       optional, np.ndarray
            coordinates of apertures, shape (k, N, 2) where k is the number of
            apertures per frame and N is the number of frames.
            NOTE: currently accepted but not used by this class.

        kwargs are passed directly to ImageDisplay, except 'autoscale'
        (default 'percentile'), which is consumed here.
        '''
        #setup image display
        self.autoscale = kwargs.pop('autoscale', 'percentile') #TODO: move up??
        ImageDisplay.__init__(self, ax, data, **kwargs)

        #setup frame slider
        self._frame = 0
        self.fsax = self.divider.append_axes('bottom', size=0.2, pad=0.25)
        # Frames are indexed 0 .. len(self) - 1, so the slider's right end is
        # the last frame.  A single-frame cube keeps a non-degenerate range
        # (values round/wrap to frame 0 in set_frame).
        # TODO: replace this Slider with a lighter-weight frame selector.
        self.frame_slider = Slider(self.fsax, 'frame', 0, max(len(self) - 1, 1),
                                   valfmt='%d')
        self.frame_slider.on_changed(self.set_frame)
        if self.use_blit:
            self.frame_slider.drawon = False

        #save background for blitting
        fig = ax.figure
        self.background = fig.canvas.copy_from_bbox(ax.bbox)

        #enable frame scroll
        fig.canvas.mpl_connect('scroll_event', self._scroll)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _needs_drawing(self):
        '''Return the artists to redraw (via blitting) after a frame change.'''
        #NOTE: temporary hack while the base class is being refined.
        #TODO: proper observers as modelled on draggables.machinery
        needs_drawing = [self.imgplt]
        if self.has_hist:
            needs_drawing.extend(self.patches)      #TODO: PatchCollection...

        if self.autoscale:
            needs_drawing.extend(self.sliders.sliders)

        return needs_drawing

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def get_data(self, i):
        '''Return the 2D image for frame *i*.'''
        return self.data[i]

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def get_frame(self):
        '''Return the current frame index.'''
        return self._frame

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def set_frame(self, i, draw=False):
        '''Set frame data. draw if requested.

        *i* may be a float (e.g. from the slider); it is rounded to the
        nearest int and wrapped into [0, len(self)) so that scrolling past
        either end continues at the other end.
        '''
        # Round *before* wrapping: wrapping first lets a value such as
        # len(self) - 0.4 round up to len(self), which is out of range.
        i = int(round(i)) % len(self)
        self._frame = i

        data = self.get_data(i)

        #set the slider axis limits
        dmin, dmax = data.min(), data.max()
        self.sliders.ax.set_ylim(dmin, dmax)
        self.sliders.valmin, self.sliders.valmax = dmin, dmax

        #set the image data
        self.imgplt.set_data(data)

        if self.autoscale:
            #set the slider positions / color limits
            vmin, vmax = self.get_autoscale_limits(data, autoscale=self.autoscale)
            self.imgplt.set_clim(vmin, vmax)
            self.sliders.set_positions((vmin, vmax))

        #TODO: update histogram values etc...

        if draw:
            needs_drawing = self._needs_drawing()
            self.draw_blit(needs_drawing)

    frame = property(get_frame, set_frame)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _scroll(self, event):
        '''Mouse-wheel callback: scroll up -> next frame, down -> previous.'''
        step = 1 if event.button == 'up' else -1
        self.frame += step
        self.frame_slider.set_val(self.frame)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def draw_blit(self, artists):
        '''Restore the saved background, draw *artists* and blit the figure.

        An artist that fails to draw is reported (with traceback) and skipped
        so the remaining artists are still drawn.
        '''
        fig = self.ax.figure
        fig.canvas.restore_region(self.background)

        for art in artists:
            try:
                self.ax.draw_artist(art)
            except Exception:
                print('drawing FAILED', art)
                traceback.print_exc()

        fig.canvas.blit(fig.bbox)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def cooDisplayFormatter(self, x, y):
        '''Prefix the base coordinate readout with the current frame number.'''
        s = ImageDisplay.cooDisplayFormatter(self, x, y)
        return 'frame %d: %s' % (self.frame, s)
Exemple #49
0
class SpecPlotter:
    """Live, scrolling spectrogram fed from a queue.

    Each item taken from ``q`` is one spectrum sampled at ``freqs``. It is
    linearly resampled onto ``len(freqs) + 1`` evenly spaced points and
    appended to a rolling history (``self.Z``) of at most
    ``floor(hist_time / time_res)`` rows, which is drawn with ``imshow`` on
    the upper axes. The latest raw spectrum is drawn as a line on the lower
    axes.

    A "Scale" slider sets the upper colour limit of the image and the upper
    y-limit of the line plot. With ``maxscale=None`` the slider range starts
    at 0 and grows with the largest absolute peak seen so far; otherwise the
    scale starts at ``maxscale`` and the slider range is never extended.

    Parameters
    ----------
    q : object
        Must provide ``empty()`` and ``get()`` (presumably a ``queue.Queue``
        or ``multiprocessing.Queue`` -- confirm with the producer).
    freqs : sequence of float
        Frequencies at which every incoming row is sampled.
    x_max : float
        Right-hand x limit used for both axes.
    hist_time : float
        Height of the history shown in the spectrogram.
    time_res : float
        Time per row; ``hist_time / time_res`` gives the history row count.
    maxscale : float or None
        Fixed scale, or None to auto-track the peak.
    """

    def __init__(self, q, freqs, x_max, hist_time, time_res, maxscale):
        self.HORIZONTAL_PIXELS = len(freqs) + 1
        self.q = q
        self.freqs = freqs
        self.x_max = x_max
        self.hist_time = hist_time
        self.time_res = time_res
        self.maxscale = maxscale
        self.hist_segments = floor(hist_time / time_res)

        # Interactive mode so plt.pause() in update() can refresh the window.
        plt.ion()
        self.figure, (self.ax, self.ax2) = plt.subplots(2)
        self.figure.canvas.mpl_connect("close_event", self.exit_evt)

        self.Z = np.zeros((self.hist_segments, self.HORIZONTAL_PIXELS))
        self.highest_peak = maxscale if maxscale is not None else 0.0
        self.scale = self.highest_peak
        self.scale_slider_touched = False

        plt.subplots_adjust(bottom=0.1)
        self.ax_scale_slider = plt.axes([0.15, 0.03, 0.7, 0.03])
        self.scale_slider = Slider(self.ax_scale_slider, "Scale", 0, self.highest_peak, self.highest_peak)
        self.scale_slider.on_changed(self.scale_slider_changed)

        self.X = np.linspace(freqs[0], freqs[-1], self.HORIZONTAL_PIXELS)
        self.row = np.zeros(len(freqs))

    def update(self):
        """Run the display loop forever: drain the queue, then redraw both axes."""
        while True:
            self._drain_queue()
            self._redraw()

    def _drain_queue(self):
        """Move every queued row into the history and track the largest peak."""
        while not self.q.empty():
            self.row = self.q.get()
            resampled = interp1d(self.freqs, self.row, kind="linear")(self.X)
            self.Z = np.vstack((self.Z, resampled))
            # Drop the oldest row so the history keeps a fixed height.
            if self.Z.shape[0] > self.hist_segments:
                self.Z = np.delete(self.Z, 0, axis=0)

            peak = np.abs(self.row).max()
            if not peak > self.highest_peak:
                continue
            self.highest_peak = peak
            if self.maxscale is not None:
                continue
            # Keep following the peak only while the slider sits at its
            # maximum, i.e. the user has not pulled the scale down by hand.
            if self.scale == self.scale_slider.valmax:
                self.scale = self.highest_peak
            self.ax_scale_slider.set_xlim(0, self.highest_peak)
            self.scale_slider.valmax = self.highest_peak
            self.scale_slider.set_val(self.scale)

    def _redraw(self):
        """Repaint the spectrogram and the latest spectrum, then let the GUI run."""
        self.ax.clear()
        self.ax.imshow(
            self.Z,
            extent=(0, self.x_max, 0, self.hist_time),
            cmap="Greys",
            vmin=0,
            vmax=self.scale,
            interpolation="none",
            aspect="auto",
        )
        self.ax.set_title("Spectrogram")
        self.ax.set_xlabel("Hz")

        self.ax2.clear()
        self.ax2.plot(self.freqs, self.row)
        self.ax2.axis([0, self.x_max, 0, self.scale])

        plt.draw()
        plt.pause(0.001)

    def exit_evt(self, evt):
        """Terminate the whole process immediately when the figure is closed."""
        os._exit(0)

    def scale_slider_changed(self, v):
        """Slider callback: remember the new display scale."""
        self.scale = v
Exemple #50
0
class viscm_editor(object):
    """Interactive editor for a colormap defined by a Bezier curve in a uniform colour space.

    The window shows the editable Bezier control points (``axes["bezier"]``),
    the resulting colormap (``axes["cm"]``), two sliders bounding the
    lightness J', and buttons to open a 3D gamut view, save the colormap as
    a Python module and open the viscm property window.
    """

    def __init__(self, uniform_space="CAM02-UCS", min_Jp=15, max_Jp=95, xp=None, yp=None):
        from .bezierbuilder import BezierModel, BezierBuilder

        self._uniform_space = uniform_space

        self.figure = plt.figure()
        axes = _viscm_editor_axes(self.figure)

        ax_btn_wireframe = plt.axes([0.7, 0.15, 0.1, 0.025])
        self.btn_wireframe = Button(ax_btn_wireframe, "Show 3D gamut")
        self.btn_wireframe.on_clicked(self.plot_3d_gamut)

        ax_btn_save = plt.axes([0.81, 0.15, 0.1, 0.025])
        self.btn_save = Button(ax_btn_save, "Save colormap")
        self.btn_save.on_clicked(self.save_colormap)

        ax_btn_props = plt.axes([0.81, 0.1, 0.1, 0.025])
        self.btn_props = Button(ax_btn_props, "Properties")
        self.btn_props.on_clicked(self.show_viscm)
        self.prop_windows = []

        # No background fill, so the grey ramp drawn by imshow shows through.
        # ``axisbg`` was removed in Matplotlib 2.2; ``facecolor`` replaces it.
        axcolor = "None"
        ax_jp_min = plt.axes([0.1, 0.1, 0.5, 0.03], facecolor=axcolor)
        ax_jp_min.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap="gray")
        ax_jp_min.set_xlim(0, 100)

        ax_jp_max = plt.axes([0.1, 0.15, 0.5, 0.03], facecolor=axcolor)
        ax_jp_max.imshow(np.linspace(0, 100, 101).reshape(1, -1), cmap="gray")

        self.jp_min_slider = Slider(ax_jp_min, r"$J'_\mathrm{min}$", 0, 100, valinit=min_Jp)
        self.jp_max_slider = Slider(ax_jp_max, r"$J'_\mathrm{max}$", 0, 100, valinit=max_Jp)

        self.jp_min_slider.on_changed(self._jp_update)
        self.jp_max_slider.on_changed(self._jp_update)

        # Default control points used when the caller supplies none.
        if xp is None:
            xp = [-2.0591553836234482, 59.377014829142524, 43.552546744036135, 4.7670857511283202, -9.5059638942617539]

        if yp is None:
            yp = [-25.664893617021221, -21.941489361702082, 38.874113475177353, 20.567375886524871, 32.047872340425585]

        self.bezier_model = BezierModel(xp, yp)
        self.cmap_model = BezierCMapModel(
            self.bezier_model, self.jp_min_slider.val, self.jp_max_slider.val, uniform_space
        )
        self.highlight_point_model = HighlightPointModel(self.cmap_model, 0.5)

        self.bezier_builder = BezierBuilder(axes["bezier"], self.bezier_model)
        self.bezier_gamut_viewer = GamutViewer2D(axes["bezier"], self.highlight_point_model, uniform_space)
        tmp = HighlightPoint2DView(axes["bezier"], self.highlight_point_model)
        self.bezier_highlight_point_view = tmp

        # draw_pure_hue_angles(axes['bezier'])
        axes["bezier"].set_xlim(-100, 100)
        axes["bezier"].set_ylim(-100, 100)

        self.cmap_view = CMapView(axes["cm"], self.cmap_model)
        self.cmap_highlighter = HighlightPointBuilder(axes["cm"], self.highlight_point_model)

        print("Click sliders at bottom to change min/max lightness")
        print("Click on colorbar to adjust gamut view")
        print("Click-drag to move control points, ")
        print("  shift-click to add, control-click to delete")

    def plot_3d_gamut(self, event):
        """Button callback: open a new figure with a 3D wireframe of the gamut."""
        fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
        self.wireframe_view = WireframeView(ax, self.cmap_model, self.highlight_point_model, self._uniform_space)
        plt.show()

    def save_colormap(self, event, path="/tmp/new_cm.py"):
        """Write the current colormap as a standalone Python module.

        The module contains the 256 sRGB entries (``cm_data``), a
        ``ListedColormap`` built from them, and the Bezier/lightness
        parameters needed to reopen the design in viscm.

        Parameters
        ----------
        event : object
            Button-click event; unused.
        path : str, optional
            Destination file. Defaults to ``/tmp/new_cm.py``, the location
            used before this parameter existed.
        """
        import textwrap

        template = textwrap.dedent(
            """
        from matplotlib.colors import ListedColormap
        from numpy import nan, inf

        # Used to reconstruct the colormap in viscm
        parameters = {{'xp': {xp},
                      'yp': {yp},
                      'min_Jp': {min_Jp},
                      'max_Jp': {max_Jp}}}

        cm_data = {array_list}

        test_cm = ListedColormap(cm_data, name=__file__)


        if __name__ == "__main__":
            import matplotlib.pyplot as plt
            import numpy as np

            try:
                from viscm import viscm
                viscm(test_cm)
            except ImportError:
                print("viscm not found, falling back on simple display")
                plt.imshow(np.linspace(0, 100, 256)[None, :], aspect='auto',
                           cmap=test_cm)
            plt.show()
        """
        )

        # Build the full text before opening the file: opening with "w"
        # truncates, so a failure here must not leave an emptied file behind.
        rgb, _ = self.cmap_model.get_sRGB(num=256)
        array_list = np.array2string(rgb, max_line_width=78, prefix="cm_data = ", separator=",")
        xp, yp = self.cmap_model.bezier_model.get_control_points()
        data = dict(
            array_list=array_list, xp=xp, yp=yp, min_Jp=self.cmap_model.min_Jp, max_Jp=self.cmap_model.max_Jp
        )
        contents = template.format(**data)

        with open(path, "w") as f:
            f.write(contents)

        print("*" * 50)
        print("Saved colormap to {}".format(path))
        print("*" * 50)

    def show_viscm(self, event):
        """Button callback: open a viscm property window for the current colormap."""
        cm = LinearSegmentedColormap.from_list("test_cm", self.cmap_model.get_sRGB(num=256)[0])
        self.prop_windows.append(viscm(cm, name="test_cm"))
        plt.show()

    def _jp_update(self, val):
        """Slider callback keeping J'min <= J'max and pushing the range to the model.

        If the sliders cross, their values are swapped. Each ``set_val``
        re-enters this callback, so the model may be updated more than once.
        The final update always receives the ordered pair.
        """
        jp_min = self.jp_min_slider.val
        jp_max = self.jp_max_slider.val

        smallest, largest = min(jp_min, jp_max), max(jp_min, jp_max)
        if (jp_min > smallest) or (jp_max < largest):
            self.jp_min_slider.set_val(smallest)
            self.jp_max_slider.set_val(largest)

        self.cmap_model.set_Jp_minmax(smallest, largest)
Exemple #51
0
class FigureContainer():
    """Top-level window: builds the figure, wires the widgets and runs the animation.

    Layout (main grid rows, left column): button row, spatial axes,
    temporal axes, frame slider, video axes. ``SelectPanel`` receives the
    whole grid (presumably it fills the right-hand column -- confirm). A
    ``PrintManager`` gets the axes, panel and slider. A ``FuncAnimation``
    advances the frame slider every 15 ms. ``plt.show()`` is called from
    the constructor.
    """

    def __init__(self, file):
        # The application uses matplotlib's default fullscreen / y-scale keys
        # itself, so move those built-in shortcuts to '{' and '}'.
        plt.rcParams['keymap.fullscreen'] = '{'
        plt.rcParams['keymap.yscale'] = '}'

        self.fig = plt.figure(facecolor='w')
        self.fig.set_size_inches(16, 9)

        layout = gspec.GridSpec(5, 2,
                                height_ratios=[2, 6, 4, 1, 6],
                                width_ratios=[10, 3])
        layout.update(left=0.05, right=0.95, wspace=0.02, hspace=.04,
                      bottom=.02, top=.98)

        new_subplot = self.fig.add_subplot
        self.spatial_axis = new_subplot(layout[1, 0])
        self.vid_axis = new_subplot(layout[4, 0])
        self.temporal_axis = new_subplot(layout[2, 0])
        self.sel_panel = SelectPanel(self.fig, layout)

        # valfmt only changes how the value is displayed; update_func is what
        # truncates the slider value to an integer frame number.
        self.slide = Slider(new_subplot(layout[3, 0]), 'Frame', 0, 100,
                            valinit=0, valfmt='%0.0f')

        self.print_manager = PrintManager(self.temporal_axis,
                                          self.spatial_axis, self.sel_panel,
                                          self.vid_axis, self.slide, file)
        self.sel_panel.set_print_manager(self.print_manager)
        self.set_slider_range()
        self.prev_frame = self.slide.val

        canvas = self.fig.canvas
        canvas.mpl_connect('pick_event', self.print_manager.on_pick)
        canvas.mpl_connect('key_press_event', self.print_manager.on_key_press)

        # Row 0 is split into four cells; only the first one holds a button.
        button_cells = gspec.GridSpecFromSubplotSpec(1, 4,
                                                     subplot_spec=layout[0, 0],
                                                     wspace=0.1, hspace=0.1)
        self.sav_but = Button(new_subplot(button_cells[0, 0]), 'Save Changes')
        self.sav_but.on_clicked(self.print_manager.save)

        self.anim = animation.FuncAnimation(self.fig, self.update_func,
                                            fargs=(), interval=15,
                                            repeat=True)
        plt.show()

    def update_func(self, j):
        """Animation step: advance the slider one frame, wrapping at the end.

        Tells the print manager about the move from the previous frame, then
        moves the slider. Returns the animation frame counter *j* unchanged.
        """
        candidate = int(self.slide.val) + 1
        frame = self.slide.valmin if candidate > self.slide.valmax else candidate
        self.print_manager.change_frame(self.prev_frame, frame)
        self.slide.set_val(frame)
        self.prev_frame = frame
        return j

    def set_slider_range(self):
        """Limit the slider to [earliest first_frame, latest last_frame] of combo_prints."""
        slide = self.slide
        slide.set_val(self.print_manager.combo_prints.first_frame.min())
        slide.valmin = self.print_manager.combo_prints.first_frame.min()
        slide.valmax = self.print_manager.combo_prints.last_frame.max()
        slide.ax.set_xlim(slide.valmin, slide.valmax)
Exemple #52
0
class PVSlicer(object):
    """Interactive position-velocity slicer for a 3-D FITS cube.

    Left axes: one plane of the cube, chosen with the "3-d slice" slider,
    with a movable slice box drawn on it. Right axes: the PV slice extracted
    along that box. The vmin/vmax sliders set the display limits of the left
    image. A button saves the current PV slice to a FITS file whose name is
    read from the terminal. Clicking in the right axes selects the plane at
    the clicked y position.
    """

    def __init__(self, filename, backend="Qt4Agg", clim=None):

        self.filename = filename

        try:
            from spectral_cube import SpectralCube
            cube = SpectralCube.read(filename, format='fits')
            self.array = cube._data
        except Exception:
            # Fall back to astropy when spectral_cube is missing or cannot
            # read the file. This was a bare ``except:``, which also
            # swallowed KeyboardInterrupt/SystemExit.
            warnings.warn("spectral_cube package is not available - using astropy.io.fits directly")
            from astropy.io import fits
            self.array = fits.getdata(filename)
            if self.array.ndim != 3:
                raise ValueError("dataset does not have 3 dimensions (install the spectral_cube package to avoid this error)")

        self.backend = backend

        import matplotlib as mpl
        mpl.use(self.backend)
        import matplotlib.pyplot as plt

        self.fig = plt.figure(figsize=(14, 8))

        self.ax1 = self.fig.add_axes([0.1, 0.1, 0.4, 0.7])

        if clim is None:
            warnings.warn("clim not defined and will be determined from the data")
            self._clim = self._auto_clim(self.array)
        else:
            self._clim = clim

        last_plane = self.array.shape[0] - 1
        self.slice = min(int(round(self.array.shape[0] / 2.)), last_plane)

        from matplotlib.widgets import Slider

        self.slice_slider_ax = self.fig.add_axes([0.1, 0.95, 0.4, 0.03])
        self.slice_slider_ax.set_xticklabels("")
        self.slice_slider_ax.set_yticklabels("")
        # Valid plane indices are 0 .. shape[0]-1. The old upper bound of
        # shape[0] let the slider select a plane that does not exist.
        # max(..., 1) avoids a zero-width slider for single-plane cubes.
        self.slice_slider = Slider(self.slice_slider_ax, "3-d slice", 0, max(last_plane, 1), valinit=self.slice, valfmt="%i")
        self.slice_slider.on_changed(self.update_slice)
        self.slice_slider.drawon = False

        self.image = self.ax1.imshow(self.array[self.slice, :,:], origin='lower', interpolation='nearest', vmin=self._clim[0], vmax=self._clim[1], cmap=plt.cm.gray)

        self.vmin_slider_ax = self.fig.add_axes([0.1, 0.90, 0.4, 0.03])
        self.vmin_slider_ax.set_xticklabels("")
        self.vmin_slider_ax.set_yticklabels("")
        self.vmin_slider = Slider(self.vmin_slider_ax, "vmin", self._clim[0], self._clim[1], valinit=self._clim[0])
        self.vmin_slider.on_changed(self.update_vmin)
        self.vmin_slider.drawon = False

        self.vmax_slider_ax = self.fig.add_axes([0.1, 0.85, 0.4, 0.03])
        self.vmax_slider_ax.set_xticklabels("")
        self.vmax_slider_ax.set_yticklabels("")
        self.vmax_slider = Slider(self.vmax_slider_ax, "vmax", self._clim[0], self._clim[1], valinit=self._clim[1])
        self.vmax_slider.on_changed(self.update_vmax)
        self.vmax_slider.drawon = False

        self.grid1 = None
        self.grid2 = None
        self.grid3 = None

        self.ax2 = self.fig.add_axes([0.55, 0.1, 0.4, 0.7])

        # Add slicing box
        self.box = SliceCurve(colors=(0.8, 0.0, 0.0))
        self.ax1.add_collection(self.box)
        self.movable = MovableSliceBox(self.box, callback=self.update_pv_slice)
        self.movable.connect()

        # Add save button
        from matplotlib.widgets import Button
        self.save_button_ax = self.fig.add_axes([0.65, 0.90, 0.20, 0.05])
        self.save_button = Button(self.save_button_ax, 'Save slice to FITS')
        self.save_button.on_clicked(self.save_fits)
        self.file_status_text = self.fig.text(0.75, 0.875, "", ha='center', va='center')
        self.set_file_status(None)

        self.pv_slice = None

        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.click)

    @staticmethod
    def _auto_clim(array):
        """Return default display limits ``(cmin - range, cmax + range)`` for *array*.

        Large cubes are sub-sampled (about 10 samples per axis) so the
        min/max stay cheap. NaN/inf values are ignored. Integer steps are
        required for slicing: ``/`` produced float steps on Python 3, which
        raise TypeError.
        """
        n1 = max(array.shape[0] // 10, 1)
        n2 = max(array.shape[1] // 10, 1)
        n3 = max(array.shape[2] // 10, 1)
        sub_array = array[::n1, ::n2, ::n3]
        finite = sub_array[~np.isnan(sub_array) & ~np.isinf(sub_array)]
        cmin = np.min(finite)
        cmax = np.max(finite)
        crange = cmax - cmin
        return (cmin - crange, cmax + crange)

    def set_file_status(self, status, filename=None):
        """Show the save-status message: 'instructions', 'saved', or anything else to clear it."""
        if status == 'instructions':
            self.file_status_text.set_text('Please enter filename in terminal')
            self.file_status_text.set_color('red')
        elif status == 'saved':
            self.file_status_text.set_text('File successfully saved to {0}'.format(filename))
            self.file_status_text.set_color('green')
        else:
            self.file_status_text.set_text('')
            self.file_status_text.set_color('black')
        self.fig.canvas.draw()

    def click(self, event):
        """Mouse-press handler: a click in the PV axes selects the plane at its y position."""

        if event.inaxes != self.ax2:
            return

        self.slice_slider.set_val(event.ydata)

    def save_fits(self, *args, **kwargs):
        """Button callback: ask for a filename on the terminal and write the PV slice.

        Does nothing when no slice has been extracted yet. Previously the
        user was still prompted, and the status text stayed on the prompt.
        """

        if self.pv_slice is None:
            return

        self.set_file_status('instructions')

        print("Enter filename: ", end='')
        try:
            plot_name = raw_input()
        except NameError:
            plot_name = input()

        # Newer astropy uses ``overwrite``; older releases only know ``clobber``.
        try:
            self.pv_slice.writeto(plot_name, overwrite=True)
        except TypeError:
            self.pv_slice.writeto(plot_name, clobber=True)

        print("Saved file to: ", plot_name)

        self.set_file_status('saved', filename=plot_name)

    def update_pv_slice(self, box):
        """Slice-box callback: extract the PV slice along *box* and show it on the right."""

        # Materialise the points: on Python 3 ``zip`` is a one-shot iterator.
        path = Path(list(zip(box.x, box.y)))
        path.width = box.width

        self.pv_slice = extract_pv_slice(self.array, path)

        self.ax2.cla()
        self.ax2.imshow(self.pv_slice.data, origin='lower', aspect='auto', interpolation='nearest')

        self.fig.canvas.draw()

    def show(self, block=True):
        """Display the figure via ``plt.show``."""
        import matplotlib.pyplot as plt
        plt.show(block=block)

    def update_slice(self, pos=None):
        """Slider callback: display plane ``round(pos)``, clamped to the cube's planes."""

        if self.array.ndim == 2:
            self.image.set_array(self.array)
        else:
            # Clamp so the slider end point or a click near the edge of the PV
            # image cannot index past the cube.
            last_plane = self.array.shape[0] - 1
            self.slice = min(max(int(round(pos)), 0), last_plane)
            self.image.set_array(self.array[self.slice, :, :])

        self.fig.canvas.draw()

    def update_vmin(self, vmin):
        """vmin slider callback; vmin is capped at the current vmax."""
        if vmin > self._clim[1]:
            self._clim = (self._clim[1], self._clim[1])
        else:
            self._clim = (vmin, self._clim[1])
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()

    def update_vmax(self, vmax):
        """vmax slider callback; vmax is floored at the current vmin."""
        if vmax < self._clim[0]:
            self._clim = (self._clim[0], self._clim[0])
        else:
            self._clim = (self._clim[0], vmax)
        self.image.set_clim(*self._clim)
        self.fig.canvas.draw()
Exemple #53
0
def view_components(estimates, img, idx):
    """View spatial and temporal components interactively.

    Opens a figure with a "Component" slider. For the selected component it
    shows the spatial weights (top left), the traces and detected spikes
    (bottom), and the weights overlaid on ``img`` (top right). The left/right
    arrow keys step the slider by one.

    Args:
        estimates: dict
            VolPy results. The keys read here are 'weights', 't', 't_sub',
            't_rec', 'spikes', 'snr' and 'locality'. Each value is indexed
            with ``idx``, so presumably it is a NumPy array (confirm).

        img: 2-D array
            Summary image shown behind the spatial overlay.

        idx: list
            Indices of the neurons to browse.
    """
    n = len(idx)
    fig = plt.figure(figsize=(10, 10))

    axcomp = plt.axes([0.05, 0.05, 0.9, 0.03])
    ax1 = plt.axes([0.05, 0.55, 0.4, 0.4])
    ax3 = plt.axes([0.55, 0.55, 0.4, 0.4])
    ax2 = plt.axes([0.05, 0.1, 0.9, 0.4])
    # NOTE(review): the slider range includes n, which is not a valid
    # component index; update() simply draws nothing for it.
    s_comp = Slider(axcomp, 'Component', 0, n, valinit=0)
    vmax = np.percentile(img, 98)

    def arrow_key_image_control(event):
        # Step one component per arrow key press, clamped to [0, n].
        if event.key == 'left':
            new_val = np.round(s_comp.val - 1)
            if new_val < 0:
                new_val = 0
            s_comp.set_val(new_val)

        elif event.key == 'right':
            new_val = np.round(s_comp.val + 1)
            if new_val > n:
                new_val = n
            s_comp.set_val(new_val)

    def update(val):
        # ``np.int`` was removed in NumPy 1.24; the builtin int is what it aliased.
        i = int(np.round(s_comp.val))
        print(f'Component:{i}')

        if i < n:

            ax1.cla()
            imgtmp = estimates['weights'][idx][i]
            ax1.imshow(imgtmp,
                       interpolation='None',
                       cmap=plt.cm.gray,
                       vmax=np.max(imgtmp) * 0.5,
                       vmin=0)
            ax1.set_title(f'Spatial component {i+1}')
            ax1.axis('off')

            ax2.cla()
            ax2.plot(estimates['t'][idx][i], alpha=0.8)
            ax2.plot(estimates['t_sub'][idx][i])
            ax2.plot(estimates['t_rec'][idx][i], alpha=0.4, color='red')
            # Spikes are drawn as markers just above the trace maximum.
            ax2.plot(estimates['spikes'][idx][i],
                     1.05 * np.max(estimates['t'][idx][i]) *
                     np.ones(estimates['spikes'][idx][i].shape),
                     color='r',
                     marker='.',
                     fillstyle='none',
                     linestyle='none')
            ax2.set_title(f'Signal and spike times {i+1}')
            ax2.legend(labels=['t', 't_sub', 't_rec', 'spikes'])
            ax2.text(0.1,
                     0.1,
                     f'snr:{round(estimates["snr"][idx][i],2)}',
                     horizontalalignment='center',
                     verticalalignment='center',
                     transform=ax2.transAxes)
            ax2.text(0.1,
                     0.07,
                     f'num_spikes: {len(estimates["spikes"][idx][i])}',
                     horizontalalignment='center',
                     verticalalignment='center',
                     transform=ax2.transAxes)
            ax2.text(0.1,
                     0.04,
                     f'locality_test: {estimates["locality"][idx][i]}',
                     horizontalalignment='center',
                     verticalalignment='center',
                     transform=ax2.transAxes)

            ax3.cla()
            ax3.imshow(img, interpolation='None', cmap=plt.cm.gray, vmax=vmax)
            # Zero weights become NaN so they are left undrawn in the overlay.
            imgtmp2 = imgtmp.copy()
            imgtmp2[imgtmp2 == 0] = np.nan
            ax3.imshow(imgtmp2,
                       interpolation='None',
                       alpha=0.5,
                       cmap=plt.cm.hot)
            ax3.axis('off')

    s_comp.on_changed(update)
    s_comp.set_val(0)
    fig.canvas.mpl_connect('key_release_event', arrow_key_image_control)
    plt.show()
Exemple #54
0
class ytViewer(object):
    def __init__(self, filenames, fold=19277, nmax=100,NORM=True,dtype=int8,DEB=0,shear=0.):
        """Load raw traces, fold them into 2-D space/time images and open the viewer.

        Parameters
        ----------
        filenames : list of str
            One binary file per channel. The layout and limit code below
            handles 1 to 3 channels.
        fold : int
            Number of samples per row when each trace is reshaped.
        nmax : int
            Number of rows shown at once; a falsy value means all rows.
        NORM : bool
            If True, colour and axis limits follow the data. Otherwise the
            fixed ``self.vmin``/``self.vmax`` (-125/125) are used.
        dtype : numpy dtype
            Sample type passed to ``fromfile``.
        DEB : object
            Not used in this method.
        shear : float
            Initial value of the shear slider.

        Per-channel state is stored in numbered attributes (``self.data0``,
        ``self.folded_data1``, ...) created through ``exec``.
        """
        self.UPDATE = False
        self.color  = False
        self.CMAP   = ['jet','seismic','Greys','plasma']
        self.COLORS = ['b','g','r']

        self.filenames   = filenames
        self.NORM        = NORM
        self.NMAX        = nmax
        self.fold        = fold
        self.increment   = 5
        self.index       = 0
        self.shear       = shear
        self.vmin        = -125
        self.vmax        = 125
        self.HORIZ_VAL   = 0.05

        for i in range(len(self.filenames)):
            exec("self.remove_len1%d = 0" %i)
            exec("self.remove_len2%d = 1" %i)

        self.Y0 = 0

        ### To set the data arrays ###
        for i in range(len(self.filenames)):
            # The read itself must be inside the try. Previously the try only
            # wrapped ``pass``, so this message could never be shown and a
            # missing file raised an unhandled error instead.
            try:
                exec("self.data%d = fromfile(filenames[i],dtype=dtype)" %i)
            except (IOError, OSError):
                print('\nYou must provide existing filenames\n')
                sys.exit()

        self.max_index = int(len(self.data0)/self.fold)
        if not(self.NMAX):
            self.NMAX = self.max_index

        # orig2: raw folded data; orig: sheared copy; orig3: cropped copy;
        # folded_data: the displayed window of rows.
        for i in range(len(self.filenames)):
            exec("self.folded_data_orig2%d = self.data%d[:self.max_index*self.fold].reshape(self.max_index,self.fold)" %(i,i))
            exec("self.folded_data_orig%d  = array(self.folded_data_orig2%d)" %(i,i))
            exec("self.folded_data_orig3%d = array(self.folded_data_orig%d)" %(i,i))
            exec("self.folded_data%d       = self.folded_data_orig3%d[:self.NMAX]" %(i,i))

        ##################################################################
        ################## Start creating the figure #####################
        self.fig = figure(figsize=(16,7))

        if len(self.filenames)==1:
            self.declare_axis_1channel()
        elif len(self.filenames)==2:
            self.declare_axis_2channel()
        elif len(self.filenames)==3:
            self.declare_axis_3channel()

        for i in range(len(self.filenames)):
            if not self.NORM:
                exec("self.im%d = self.ax%d.imshow(self.folded_data%d, interpolation='nearest', aspect='auto',origin='lower', vmin=self.vmin, vmax=self.vmax)" %(i,i,i))
            else:
                exec("self.im%d = self.ax%d.imshow(self.folded_data%d, interpolation='nearest', aspect='auto',origin='lower', vmin=self.folded_data%d.min(), vmax=self.folded_data%d.max())" %(i,i,i,i,i))

        for i in range(len(self.filenames)):
            exec("self.cursor%d = Cursor(self.ax%d, useblit=True, color='red', linewidth=2)" %(i,i))

        self.axhh = axes([0.02,0.25,0.12,0.62])
        for i in range(len(self.filenames)):
            exec("self.hline%d, = self.axh%d.plot(self.folded_data%d[self.Y0,:])" %(i,i,i))
            exec("self.axh%d.set_xlim(0,len(self.folded_data%d[0,:]))" %(i,i))
            exec("self.hhline%d, = self.axhh.plot(self.folded_data%d.mean(1),arange(self.NMAX),self.COLORS[i])" %(i,i))
        self.axhh.set_ylim(0,self.NMAX-1)
        if not self.NORM:
            for i in range(len(self.filenames)):
                exec("self.axh%d.set_ylim(self.vmin, self.vmax)" %i)
            self.axhh.set_xlim(self.vmin, self.vmax)
        else:
            for i in range(len(self.filenames)):
                exec("self.axh%d.set_ylim(self.folded_data%d.min(), self.folded_data%d.max())" %(i,i,i))
            if len(self.filenames)==1:
                LIM_MIN = self.folded_data0.mean(1).min()
                LIM_MAX = self.folded_data0.mean(1).max()
            elif len(self.filenames)==2:
                LIM_MIN = min(self.folded_data0.mean(1).min(),self.folded_data1.mean(1).min())
                LIM_MAX = max(self.folded_data0.mean(1).max(),self.folded_data1.mean(1).max())
            elif len(self.filenames)==3:
                LIM_MIN = min(self.folded_data0.mean(1).min(),self.folded_data1.mean(1).min(),self.folded_data2.mean(1).min())
                LIM_MAX = max(self.folded_data0.mean(1).max(),self.folded_data1.mean(1).max(),self.folded_data2.mean(1).max())
            self.axhh.set_xlim(LIM_MIN-1,LIM_MAX+1)

        # create 'remove_len1' slider
        for i in range(len(self.filenames)):
            exec("self.remove_len1%d_slider   = Slider(self.remove_len1%d_sliderax,'beg',0.,self.fold,self.remove_len1%d,'%s')" %(i,i,i,'%d'))
            exec("self.remove_len1%d_slider.on_changed(self.update_tab)" %i)

        # create 'remove_len2' slider
        for i in range(len(self.filenames)):
            exec("self.remove_len2%d_slider   = Slider(self.remove_len2%d_sliderax,'end',1.,self.fold,self.remove_len2%d,'%s')" %(i,i,i,'%d'))
            exec("self.remove_len2%d_slider.on_changed(self.update_tab)" %i)
        # create 'index' slider
        self.index_sliderax = axes([0.175,0.975,0.775,0.02])
        self.index_slider   = Slider(self.index_sliderax,'index',0,self.max_index-self.increment,0,'%d')
        self.index_slider.on_changed(self.update_tab)
        # create 'nmax' slider
        self.nmax_sliderax = axes([0.175,0.955,0.775,0.02])
        self.nmax_slider   = Slider(self.nmax_sliderax,'nmax',0,self.max_index,self.NMAX,'%d')
        self.nmax_slider.on_changed(self.update_tab)
        # create 'shear' slider
        self.shear_sliderax = axes([0.175,0.935,0.775,0.02])
        self.shear_slider   = Slider(self.shear_sliderax,'Shear',-0.5,0.5,self.shear,'%1.2f')
        self.shear_slider.on_changed(self.update_shear)

        cid  = self.fig.canvas.mpl_connect('motion_notify_event', self.mousemove)
        cid2 = self.fig.canvas.mpl_connect('key_press_event', self.keypress)

        VERT_VAL = -4
        font0 = FontProperties()
        font1 = font0.copy()
        font1.set_weight('bold')
        mpl.pyplot.text(-0.72,-32+VERT_VAL,'Useful keys:',fontsize=18,fontproperties=font1)
        mpl.pyplot.text(-0.72,-41+VERT_VAL,'"c" to change colormap\n "v" to change vertical\n      /colorscale\n " " to pause\n "w"/"x" set Ch1 REMOVE\n       sliders values to Ch2/3\n "t" Retrigger mode\n (NOT TOO MUCH POINTS) \n "q" to exit',fontsize=18)

        # Status strip along the bottom: green when UPDATE is on, orange-red when off.
        self.axe_toggledisplay  = self.fig.add_axes([0.,0.,1.0,0.02])
        if self.UPDATE:
            self.plot_circle(0,0,2,fc='#00FF7F')
        else:
            self.plot_circle(0,0,2,fc='#FF4500')
        mpl.pyplot.axis('off')

        gobject.idle_add(self.update_plot)
        show()
        
    ### BEGIN main loop ###
    def update_plot(self):
        """Idle callback: scroll every channel forward by ``self.increment`` rows.

        While ``self.UPDATE`` is set, this:
        1. Recomputes each channel's displayed window starting at ``self.index``.
        2. Refreshes the image, the horizontal cut and the row-mean line.
        3. Advances the index slider and redraws.

        Returns True (keep the idle callback scheduled) after one step, or
        False immediately when updating is switched off.
        """
        if not self.UPDATE:
            return False

        window = slice(self.index, self.index + self.NMAX)
        for ch in range(len(self.filenames)):
            source = getattr(self, 'folded_data_orig3%d' % ch)
            setattr(self, 'folded_data%d' % ch, source[window])

        for ch in range(len(self.filenames)):
            shown = getattr(self, 'folded_data%d' % ch)
            getattr(self, 'im%d' % ch).set_data(shown)
            getattr(self, 'hline%d' % ch).set_ydata(shown[self.Y0, :])
            getattr(self, 'hhline%d' % ch).set_xdata(shown.mean(1))

        self.index = self.index + self.increment
        self.index_slider.set_val(self.index)
        self.fig.canvas.draw()
        return True
    ### END main loop ###
        
    def update_tab(self,val):
        """Common slider callback: read every slider into state and rebuild the views.

        ``val`` (the slider value passed by matplotlib) is ignored; all
        sliders are re-read instead.
        """
        for ch in range(len(self.filenames)):
            for side in ('1', '2'):
                attr = 'remove_len%s%d' % (side, ch)
                setattr(self, attr, int(getattr(self, attr + '_slider').val))

        self.index = int(round(self.index_slider.val, 0))
        self.NMAX = int(round(self.nmax_slider.val, 0))
        self.Y0 = 0

        self.update_tabs()
        self.norm_fig()
        self.fig.canvas.draw()
        
    def update_tabs(self):
        """Rebuild every channel's displayed data from the raw folded traces.

        For channel ``i`` this:
        1. Re-applies the current shear (``process_data``).
        2. Crops ``remove_len1i`` columns from the start and ``remove_len2i``
           from the end (into ``folded_data_orig3i``).
        3. Keeps rows ``index .. index+NMAX`` (into ``folded_data_i``).
        """
        for i in range(len(self.filenames)):
            # Shear first so the crop uses data sheared with the *current*
            # ``self.shear``. The old order cropped the previously sheared
            # array and only caught up on a second call.
            self.process_data(self.shear, i)
            start = getattr(self, 'remove_len1%d' % i)
            # ``a:-0`` is an empty slice; treat 0 as "remove nothing".
            stop = -getattr(self, 'remove_len2%d' % i) or None
            sheared = getattr(self, 'folded_data_orig%d' % i)
            cropped = array(sheared[:, start:stop])
            setattr(self, 'folded_data_orig3%d' % i, cropped)
            setattr(self, 'folded_data%d' % i, cropped[self.index:(self.index+self.NMAX)])
            
    def process_data(self,val,i):
        """Redress data in the space/time diagram.

        Row ``k`` of the raw folded data of channel ``i``
        (``folded_data_orig2i``) is rolled by ``int(k*val)`` samples. The
        result is stored as ``folded_data_origi``.

        Uses getattr/setattr instead of ``exec``: on Python 3, ``exec`` cannot
        create the local ``dd``, so the original raised NameError.
        """
        source = getattr(self, 'folded_data_orig2%d' % i)
        dd = source.copy()
        for k in range(0, dd.shape[0]):
            dd[k, :] = roll(source[k, :], int(k*val))
        setattr(self, 'folded_data_orig%d' % i, dd)
        
    def norm_fig(self):
        """Clear and redraw every channel's image, cut and row-mean plot with fresh limits.

        When ``self.NORM`` is False the fixed ``self.vmin``/``self.vmax`` are
        used. Otherwise colour and axis limits come from the current
        ``folded_dataN`` arrays.
        """
        # Range of the row means over all channels; only used by the NORM
        # branch below.
        # NOTE(review): LIM_MIN/LIM_MAX are only defined for 1-3 channels; with
        # more channels the NORM branch would raise NameError.
        if len(self.filenames)==1:
            LIM_MIN = self.folded_data0.mean(1).min()
            LIM_MAX = self.folded_data0.mean(1).max()
        elif len(self.filenames)==2:
            LIM_MIN = min(self.folded_data0.mean(1).min(),self.folded_data1.mean(1).min())
            LIM_MAX = max(self.folded_data0.mean(1).max(),self.folded_data1.mean(1).max())
        elif len(self.filenames)==3:
            LIM_MIN = min(self.folded_data0.mean(1).min(),self.folded_data1.mean(1).min(),self.folded_data2.mean(1).min())
            LIM_MAX = max(self.folded_data0.mean(1).max(),self.folded_data1.mean(1).max(),self.folded_data2.mean(1).max())
        self.axhh.clear()
        if not self.NORM:
            for i in range(len(self.filenames)):
                # NOTE(review): imshow always uses CMAP[0] here, which presumably
                # resets a colormap cycled with the "c" key -- confirm intent.
                exec("self.ax%d.clear()" %i)
                exec("self.im%d = self.ax%d.imshow(self.folded_data%d,cmap=self.CMAP[0], interpolation='nearest', aspect='auto',origin='lower', vmin=self.vmin, vmax=self.vmax)" %(i,i,i))
                exec("self.axh%d.clear()" %i)
                exec("self.hline%d, = self.axh%d.plot(self.folded_data%d[self.Y0,:])" %(i,i,i))
                exec("self.axh%d.set_ylim(self.vmin, self.vmax)" %i)
                exec("self.axh%d.set_xlim(0, len(self.folded_data%d[0]))" %(i,i))
                exec("self.hhline%d, = self.axhh.plot(self.folded_data%d.mean(1),arange(len(self.folded_data%d.mean(1))),self.COLORS[i])" %(i,i,i))
            self.axhh.set_ylim(0,self.max_index-self.index-1)
            self.axhh.set_xlim(self.vmin, self.vmax)
        else:
            for i in range(len(self.filenames)):
                exec("self.ax%d.clear()" %i)
                exec("self.im%d = self.ax%d.imshow(self.folded_data%d,cmap=self.CMAP[0], interpolation='nearest', aspect='auto',origin='lower', vmin=self.folded_data%d.min(), vmax=self.folded_data%d.max())" %(i,i,i,i,i))
                exec("self.axh%d.clear()" %i)
                exec("self.hline%d, = self.axh%d.plot(self.folded_data%d[self.Y0,:])" %(i,i,i))
                exec("self.axh%d.set_ylim(self.folded_data%d.min(), self.folded_data%d.max())" %(i,i,i))
                exec("self.axh%d.set_xlim(0, len(self.folded_data%d[0]))" %(i,i))
                exec("self.hhline%d, = self.axhh.plot(self.folded_data%d.mean(1),arange(len(self.folded_data%d.mean(1))),self.COLORS[i])" %(i,i,i))
            # y range follows the number of displayed rows of channel 0.
            self.axhh.set_ylim(0,len(self.folded_data0.mean(1)))
            self.axhh.set_xlim(LIM_MIN-1,LIM_MAX+1)
        self.fig.canvas.draw()

    ### BEGIN Slider actions ###
    def update_shear(self,val):
        """Shear-slider callback.

        Reads the shear back from ``self.shear_slider`` (``val`` itself is not
        used) and rounds it to two decimals. Then recomputes the channel data,
        tab 0 and the figure, and redraws the canvas.
        """
        slider_value = self.shear_slider.val
        self.shear = round(slider_value, 2)
        # recompute the data first, then refresh what is displayed
        self.update_tabs()
        self.update_tab(0)
        self.norm_fig()
        canvas = self.fig.canvas
        canvas.draw()

    def update_cut(self):
        """Move each channel's horizontal-cut line to row ``self.Y0`` and redraw.

        Uses ``getattr`` rather than ``exec`` to reach the per-channel
        ``hline<i>`` / ``folded_data<i>`` attributes.
        """
        for i in range(len(self.filenames)):
            line = getattr(self, 'hline%d' % i)
            line.set_ydata(getattr(self, 'folded_data%d' % i)[self.Y0, :])
        self.fig.canvas.draw()
    ### END Slider actions ###
    
    ### BEGIN actions to the window ###
    def toggle_update(self):
        """Toggle live updating and redraw the status indicator.

        Flips ``self.UPDATE``; when it becomes true, ``self.update_plot`` is
        registered with ``gobject.idle_add``. Then flips ``self.color`` and
        replaces the indicator circle in a thin axes strip at the bottom of
        the figure. The circle is ``#FF4500`` when ``self.color`` is false and
        ``#00FF7F`` when it is true.
        """
        self.UPDATE = not self.UPDATE
        if self.UPDATE:
            gobject.idle_add(self.update_plot)
        self.color = not self.color
        # The two indicator states differ only by their fill colour.
        fill = '#00FF7F' if self.color else '#FF4500'
        self.patch.remove()
        self.axe_toggledisplay = self.fig.add_axes([0., 0., 1.0, 0.02])
        self.axe_toggledisplay.clear()
        self.plot_circle(0, 0, 2, fc=fill)
        mpl.pyplot.axis('off')
        self.fig.canvas.draw()
                
    def keypress(self, event):
        """Keyboard handler for the viewer window.

        Keys:

        - ``q``: exit the program.
        - ``c``: cycle the colormap list.
        - ``v``: toggle per-channel normalisation.
        - space: play/pause via ``toggle_update``.
        - ``w`` / ``x``: copy channel 1's REMOVE slider values to channel 2 / 3.
        - ``t``: run ``smooth_array``.

        Any other key is reported as unknown.

        The print calls take a single string argument, so the output is the
        same under Python 2 and Python 3 (the former print statements were a
        SyntaxError on Python 3).
        """
        key = event.key
        if key == 'q': # eXit
            sys.exit()
        elif key == 'c':
            self.CMAP = roll(self.CMAP, -1)
            self.norm_fig()
            self.fig.canvas.draw()
        elif key == 'v':
            self.NORM = not(self.NORM)
            self.norm_fig()
        elif key == ' ': # play/pause
            self.toggle_update()
        elif key == 'w':
            if len(self.filenames) >= 2:
                print('Set REMOVE values of channel 1 to channel 2')
                self.remove_len11_slider.set_val(self.remove_len10)
                self.remove_len21_slider.set_val(self.remove_len20)
        elif key == 'x':
            if len(self.filenames) >= 3:
                print('Set REMOVE values of channel 1 to channel 3')
                self.remove_len12_slider.set_val(self.remove_len10)
                self.remove_len22_slider.set_val(self.remove_len20)
        elif key == 't':
            print('Trying to smooth from index ' + str(self.index))
            self.smooth_array()
            print('Done MF')
        else:
            print('Key ' + str(key) + ' not known')
    
    def mousemove(self, event):
        """Track the mouse over a channel image and refresh the horizontal cuts.

        Events outside the channel image axes ``ax0 .. ax<n-1>`` are ignored.
        Otherwise the rounded data coordinates are stored in ``self.X0`` /
        ``self.Y0`` and ``update_cut`` is called.

        The axes check covers any number of channels. The original only
        checked 1-3 channels, so with other counts events outside the images
        reached ``round(None)``.
        """
        # called on each mouse motion to get mouse position
        image_axes = [getattr(self, 'ax%d' % i) for i in range(len(self.filenames))]
        if not any(event.inaxes is ax for ax in image_axes):
            return
        self.X0 = int(round(event.xdata, 0))
        self.Y0 = int(round(event.ydata, 0))
        self.update_cut()
    ### END actions to the window ###
    
    ### Divers useful functions ###
    def plot_circle(self,x,y,r,fc='r'):
        """Add a filled circle of radius ``r`` centred on ``(x, y)``.

        The circle is added to the current pyplot axes, and the patch returned
        by ``add_patch`` is kept in ``self.patch``.
        """
        disc = mpl.patches.Circle((x, y), radius=r, fc=fc)
        target = mpl.pyplot.gca()
        self.patch = target.add_patch(disc)
    
    def smooth_array(self):
        """Re-trigger the channel rows and display the result in figure number 5.

        Requires at least two channels and live updating switched off
        (``self.UPDATE`` false); otherwise only a message is printed.

        Each ``folded_data<i>`` is first reset to rows
        ``index .. index+NMAX`` of ``folded_data_orig3<i>``. The last channel
        is passed to ``trig_by_interpolation_pola`` as the trigger signal and
        the other channels as companions. The results are stored in
        ``folded_data_retriggered<i>``.

        Channel 0 (two channels), or channels 0 and 1 (three channels), are
        then drawn with ``imshow``, each scaled to its own min/max. With more
        than three channels nothing is re-triggered, but the figure is still
        shown and drawn.
        """
        if len(self.filenames)>=2 and self.UPDATE==False:
            self.fig2 = figure(5,figsize=(16,7))
            clf()
            # restore the current window of every channel
            for i in range(len(self.filenames)):
                exec("self.folded_data%d = self.folded_data_orig3%d[self.index:(self.index+self.NMAX)]" %(i,i))
            if len(self.filenames)==2:
                # channel 1 is the trigger; channel 0 receives the same row shifts
                self.folded_data_retriggered1,self.folded_data_retriggered0 = self.trig_by_interpolation_pola(self.folded_data1,self.folded_data0)
                self.fig2ax0 = self.fig2.add_subplot(111)
                self.fig2ax0.clear()
                self.fig2ax0.imshow(self.folded_data_retriggered0,cmap=self.CMAP[0], interpolation='nearest', aspect='auto',origin='lower', vmin=self.folded_data_retriggered0.min(), vmax=self.folded_data_retriggered0.max())
            elif len(self.filenames)==3:
                # channel 2 is the trigger; channels 0 and 1 receive the same row shifts
                self.folded_data_retriggered2,self.folded_data_retriggered0,self.folded_data_retriggered1 = self.trig_by_interpolation_pola(self.folded_data2,self.folded_data0,self.folded_data1)
                self.fig2ax1 = self.fig2.add_subplot(121)
                self.fig2ax1.clear()
                self.fig2ax1.imshow(self.folded_data_retriggered0,cmap=self.CMAP[0], interpolation='nearest', aspect='auto',origin='lower', vmin=self.folded_data_retriggered0.min(), vmax=self.folded_data_retriggered0.max())
                self.fig2ax2 = self.fig2.add_subplot(122)
                self.fig2ax2.clear()
                self.fig2ax2.imshow(self.folded_data_retriggered1,cmap=self.CMAP[0], interpolation='nearest', aspect='auto',origin='lower', vmin=self.folded_data_retriggered1.min(), vmax=self.folded_data_retriggered1.max())
            # NOTE(review): show(False) is presumably a non-blocking show — confirm
            show(False)
            self.fig2.canvas.draw()
        else:
            print 'This function REQUIRES a trigger'
            
    
    def trig_by_interpolation_pola(self,data,pola1,pola2=None,thr=30,FACT=25,num_trig=0,DOWN=False):
        """Align the rows of ``data`` and its companion channels on a trigger.

        For every row ``i``:

        1. ``data[i]``, ``pola1[i]`` and (if given) ``pola2[i]`` are each
           linearly interpolated (``interpolate.interp1d``) onto a grid
           ``FACT`` times denser.
        2. In the up-sampled ``data`` row, the ``num_trig``-th crossing of
           ``thr`` is located: upward with ``self.find_up``, or downward with
           ``self.find_down`` when ``DOWN`` is true.

        Every up-sampled row is then circularly shifted (``roll``) so its
        crossing lines up with the earliest crossing found. The same shift is
        applied to the matching rows of the companion channels. Rows whose
        interpolation or trigger lookup fails are skipped with a warning.

        Returns
        -------
        tuple of arrays
            ``(data, pola1, pola2)`` re-triggered when ``pola2`` is given,
            otherwise ``(data, pola1)``. Each array holds one up-sampled row
            per kept input row. The arrays are empty if no row could be
            triggered.
        """
        find_edges = self.find_down if DOWN else self.find_up
        # the trigger signal comes first, companion channels follow
        channels = [data, pola1] if pola2 is None else [data, pola1, pola2]
        shifts = []
        traces = [[] for _ in channels]
        for i in range(len(data)):
            # Only the channels actually supplied are touched: the original
            # always evaluated len(lp2[i]) and raised NameError when pola2
            # was None.
            interpolants = [interpolate.interp1d(linspace(0, len(ch[i]), len(ch[i])), ch[i])
                            for ch in channels]
            grids = [linspace(0, len(ch[i]), FACT * len(ch[i])) for ch in channels]
            try:
                rows = [interp(grid) for interp, grid in zip(interpolants, grids)]
                # several crossings may be found in a row: keep the num_trig-th
                shift = find_edges(rows[0], thr)[num_trig]
            except Exception:
                print('%d WARNING: Error repering => skiping a line' % i)
                continue
            shifts.append(shift)
            for trace, row in zip(traces, rows):
                trace.append(row)
        if not shifts:
            # every row was skipped: nothing to align
            return tuple(array([]) for _ in channels)
        shifts = array(shifts) - array(shifts).min()
        return tuple(array([roll(row, -s) for row, s in zip(trace, shifts)])
                     for trace in traces)
    
    def find_down(self,d, threshold):
        """Return the indices where ``d`` falls below ``threshold``.

        Samples with ``d < threshold`` become 255 in a ``uint8`` copy, all
        others 0. The first difference along axis 0 equals 255 exactly at the
        positions ``k`` where sample ``k`` is not below the threshold and
        sample ``k + 1`` is. The opposite transition wraps to 1 in ``uint8``
        arithmetic and is therefore ignored. For 2-D input only the row
        indices of such transitions are returned.
        """
        gate = zeros(shape=d.shape, dtype=uint8)
        gate[d < threshold] = 255
        step = gate[1:] - gate[:-1]
        hits = step == 255
        return hits.nonzero()[0]
    def find_up(self,d, threshold):
        """Return the indices where ``d`` rises above ``threshold``.

        Samples with ``d > threshold`` become 255 in a ``uint8`` copy, all
        others 0. The first difference along axis 0 equals 255 exactly at the
        positions ``k`` where sample ``k`` is not above the threshold and
        sample ``k + 1`` is. The falling transition wraps to 1 in ``uint8``
        arithmetic and is ignored. For 2-D input only row indices are
        returned.
        """
        gate = zeros(shape=d.shape, dtype=uint8)
        gate[d > threshold] = 255
        step = gate[1:] - gate[:-1]
        hits = step == 255
        return hits.nonzero()[0]
        
    def declare_axis_1channel(self):
        """Create the axes for the single-channel layout.

        ``ax0`` holds the image and ``axh0`` the strip below it.
        ``remove_len10_sliderax`` / ``remove_len20_sliderax`` are the two
        slider strips above it. Everything is shifted right by
        ``self.HORIZ_VAL``.
        """
        left = 0.125 + self.HORIZ_VAL
        self.ax0 = axes([left, 0.25, 0.81, 0.62])
        self.axh0 = axes([left, 0.05, 0.81, 0.15])
        self.remove_len10_sliderax = axes([left, 0.91, 0.78, 0.02])
        self.remove_len20_sliderax = axes([left, 0.88, 0.78, 0.02])
    def declare_axis_2channel(self):
        """Create the axes for the two-channel (side-by-side) layout.

        The axes are created in this order:

        - image axes ``ax0``, ``ax1``;
        - lower strips ``axh0``, ``axh1``;
        - slider strips ``remove_len10``, ``remove_len11``, ``remove_len20``,
          ``remove_len21`` (each with a ``_sliderax`` suffix).

        All are shifted right by ``self.HORIZ_VAL``.
        """
        lefts = (0.125 + self.HORIZ_VAL, 0.54 + self.HORIZ_VAL)
        for col, x0 in enumerate(lefts):
            setattr(self, 'ax%d' % col, axes([x0, 0.25, 0.395, 0.62]))
        for col, x0 in enumerate(lefts):
            setattr(self, 'axh%d' % col, axes([x0, 0.05, 0.395, 0.15]))
        for row, y0 in ((1, 0.91), (2, 0.88)):
            for col, x0 in enumerate(lefts):
                setattr(self, 'remove_len%d%d_sliderax' % (row, col),
                        axes([x0, y0, 0.37, 0.02]))
    def declare_axis_3channel(self):
        """Create the axes for the three-channel (side-by-side) layout.

        The axes are created in this order:

        - image axes ``ax0`` .. ``ax2``;
        - lower strips ``axh0`` .. ``axh2``;
        - slider strips ``remove_len10`` .. ``remove_len12``, then
          ``remove_len20`` .. ``remove_len22`` (each with a ``_sliderax``
          suffix).

        All are shifted right by ``self.HORIZ_VAL``.
        """
        lefts = (0.125 + self.HORIZ_VAL, 0.405 + self.HORIZ_VAL, 0.685 + self.HORIZ_VAL)
        for col, x0 in enumerate(lefts):
            setattr(self, 'ax%d' % col, axes([x0, 0.25, 0.25, 0.62]))
        for col, x0 in enumerate(lefts):
            setattr(self, 'axh%d' % col, axes([x0, 0.05, 0.25, 0.15]))
        for row, y0 in ((1, 0.91), (2, 0.88)):
            for col, x0 in enumerate(lefts):
                setattr(self, 'remove_len%d%d_sliderax' % (row, col),
                        axes([x0, y0, 0.25, 0.02]))
Exemple #55
0
class timeseriesViewer():
    """Class for tsview.py

    Example:
        cmd = 'tsview.py timeseries_ERA5_ramp_demErr.h5'
        obj = timeseriesViewer(cmd)
        obj.configure()
        obj.plot()

    Figure 1 shows the displacement map of one epoch together with a time
    slider; the left/right arrow keys step through the epochs. Clicking on the
    map plots the displacement time-series of that pixel in figure 2.
    """
    def __init__(self, cmd=None, iargs=None):
        """Store the command-line arguments and initialise figure handles.

        Parameters
        ----------
        cmd : str, optional
            full command line; when given, ``iargs`` becomes its
            whitespace-separated tokens after the first one (the script name).
        iargs : list of str, optional
            arguments later passed to ``cmd_line_parse`` in ``configure``.
        """
        if cmd:
            iargs = cmd.split()[1:]
        self.cmd = cmd
        self.iargs = iargs
        # print command line (iargs stays None when neither argument is given)
        cmd = '{} '.format(os.path.basename(__file__))
        if iargs:
            cmd += ' '.join(iargs)
        print(cmd)

        # figure variables
        self.figname_img = 'Cumulative Displacement Map'
        self.figsize_img = None
        self.fig_img = None
        self.ax_img = None
        self.cbar_img = None
        self.img = None

        self.ax_tslider = None
        self.tslider = None

        self.figname_pts = 'Point Displacement Time-series'
        self.figsize_pts = None
        self.fig_pts = None
        self.ax_pts = None

    def configure(self):
        """Parse ``self.iargs`` and copy every parsed option onto ``self``."""
        inps = cmd_line_parse(self.iargs)
        inps, self.atr = read_init_info(inps)
        # copy inps to self object
        for key, value in inps.__dict__.items():
            setattr(self, key, value)
        # input figsize for the point time-series plot
        self.figsize_pts = self.fig_size
        self.pts_marker = 'r^'
        self.pts_marker_size = 6.

    def plot(self):
        """Read the data, build both figures, connect the events and optionally save/show."""
        # read 3D time-series
        self.ts_data, self.mask = read_timeseries_data(self)[0:2]

        # Figure 1 - Cumulative Displacement Map
        self.fig_img = plt.figure(self.figname_img, figsize=self.figsize_img)

        # Figure 1 - Axes 1 - Displacement Map
        self.ax_img = self.fig_img.add_axes([0.125, 0.25, 0.75, 0.65])
        img_data = np.array(self.ts_data[0][self.idx, :, :])
        img_data[self.mask == 0] = np.nan
        self.plot_init_image(img_data)

        # Figure 1 - Axes 2 - Time Slider
        self.ax_tslider = self.fig_img.add_axes([0.2, 0.1, 0.6, 0.07])
        self.plot_init_time_slider(init_idx=self.idx, ref_idx=self.ref_idx)
        self.tslider.on_changed(self.update_time_slider)

        # Figure 2 - Time Series Displacement - Point
        self.fig_pts, self.ax_pts = plt.subplots(num=self.figname_pts,
                                                 figsize=self.figsize_pts)
        # d_ts was previously unbound (UnboundLocalError) when saving
        # without an initial point
        d_ts = None
        if self.yx:
            d_ts = self.plot_point_timeseries(self.yx)

        # Output
        if self.save_fig:
            if d_ts is not None:
                save_ts_plot(self.yx, self.fig_img, self.fig_pts, d_ts, self)
            else:
                vprint('WARNING: no initial point (yx) given, skip saving the time-series plot.')

        # Final linking of the canvas to the plots.
        self.fig_img.canvas.mpl_connect('button_press_event',
                                        self.update_plot_timeseries)
        self.fig_img.canvas.mpl_connect('key_press_event', self.on_key_event)
        if self.disp_fig:
            vprint('showing ...')
            msg = '\n------------------------------------------------------------------------'
            msg += '\nTo scroll through the image sequence:'
            msg += '\n1) Move the slider, OR'
            msg += '\n2) Press left or right arrow key (if not responding, click the image and try again).'
            msg += '\n------------------------------------------------------------------------'
            vprint(msg)
            plt.show()

    def _wrap_image(self, data):
        """Apply the optional display wrapping to a map.

        When ``self.wrap`` is set, ``data`` is first scaled in place by
        ``self.range2phase`` for radian display, then passed to ``ut.wrap``.
        Returns the (possibly new) array.
        """
        if self.wrap:
            if self.disp_unit_img == 'radian':
                data *= self.range2phase
            data = ut.wrap(data, wrap_range=self.wrap_range)
        return data

    def _read_image(self, idx):
        """Return a masked (NaN), optionally wrapped copy of epoch ``idx`` of the first file."""
        data = np.array(self.ts_data[0][idx, :, :])
        data[self.mask == 0] = np.nan
        return self._wrap_image(data)

    def _show_epoch(self, idx):
        """Set the map title and image data to epoch ``idx`` (without redrawing)."""
        disp_date = self.dates[idx].strftime('%Y-%m-%d')
        self.ax_img.set_title('N = {n}, Time = {t}'.format(n=idx, t=disp_date),
                              fontsize=self.font_size)
        self.img.set_data(self._read_image(idx))

    def plot_init_image(self, img_data):
        """Draw the initial displacement map with ``view.plot_slice``.

        Also sets ``self.fig_title`` and the initial point of interest
        (``self.pts_yx`` / ``self.pts_lalo``). Returns the image and colorbar
        handles.
        """
        # prepare data
        img_data = self._wrap_image(img_data)

        # Title and Axis Label
        disp_date = self.dates[self.idx].strftime('%Y-%m-%d')
        self.fig_title = 'N = {}, Time = {}'.format(self.idx, disp_date)

        # Initial Pixel of interest
        self.pts_yx = None
        self.pts_lalo = None
        if self.yx and self.yx != self.ref_yx:
            self.pts_yx = np.array(self.yx).reshape(-1, 2)
            if self.lalo:
                self.pts_lalo = np.array(self.lalo).reshape(-1, 2)

        # call view.py to plot
        self.img, self.cbar_img = view.plot_slice(self.ax_img, img_data,
                                                  self.atr, self)[2:4]
        return self.img, self.cbar_img

    def plot_init_time_slider(self, init_idx=-1, ref_idx=0):
        """Create the decimal-year time slider.

        Every epoch is marked with a black bar; the reference epoch gets a
        wider crimson bar.
        """
        val_step = np.min(np.diff(self.yearList))
        val_min = self.yearList[0]
        val_max = self.yearList[-1]

        self.tslider = Slider(self.ax_tslider,
                              label='Years',
                              valinit=self.yearList[init_idx],
                              valmin=val_min,
                              valmax=val_max,
                              valstep=val_step)

        bar_width = val_step / 4.
        datex = np.array(self.yearList) - bar_width / 2.
        self.tslider.ax.bar(datex,
                            np.ones(len(datex)),
                            bar_width,
                            facecolor='black',
                            ecolor=None)
        self.tslider.ax.bar(datex[ref_idx],
                            1.,
                            bar_width * 3,
                            facecolor='crimson',
                            ecolor=None)

        # xaxis tick format: one decimal when all dates fall in the same year
        if np.floor(val_max) == np.floor(val_min):
            digit = 10.
        else:
            digit = 1.
        self.tslider.ax.set_xticks(
            np.round(np.linspace(val_min, val_max, num=5) * digit) / digit)
        self.tslider.ax.xaxis.set_minor_locator(MultipleLocator(1. / 12.))
        self.tslider.ax.set_xlim([val_min, val_max])
        self.tslider.ax.set_yticks([])
        return self.tslider

    def update_time_slider(self, val):
        """Update Displacement Map using Slider"""
        # snap the slider value to the nearest acquisition
        idx = np.argmin(np.abs(np.array(self.yearList) - self.tslider.val))
        self._show_epoch(idx)
        self.idx = idx
        self.fig_img.canvas.draw()

    def plot_point_timeseries(self, yx):
        """Plot point displacement time-series at pixel [y, x]

        Parameters: yx : list of 2 int
        Returns:    d_ts : list of 1D np.array, one per file, last file first
                    (each array includes the plotting offset, if any)
        """
        self.ax_pts.cla()

        # plot scatter in different size for different files
        num_file = len(self.ts_data)
        if num_file <= 2:
            ms_step = 4
        elif num_file == 3:
            ms_step = 3
        elif num_file == 4:
            ms_step = 2
        else:
            ms_step = 1

        d_ts = []
        y = yx[0] - self.pix_box[1]
        x = yx[1] - self.pix_box[0]
        for i in range(num_file - 1, -1, -1):
            # Work on a copy: for numpy-backed ts_data, ts_data[i][:, y, x]
            # is a view, and the in-place zeroing/offset below used to
            # modify the stored time-series (offsets accumulated on every
            # click and leaked into the displayed map).
            d_tsi = np.array(self.ts_data[i][:, y, x])
            if self.zero_first:
                d_tsi -= d_tsi[self.zero_idx]
            d_ts.append(d_tsi)

            # get plot parameter - namespace ppar
            ppar = argparse.Namespace()
            ppar.label = self.file_label[i]
            ppar.ms = self.marker_size - ms_step * (num_file - 1 - i)
            ppar.mfc = pp.mplColors[num_file - 1 - i]
            if self.mask[y, x] == 0:
                ppar.mfc = 'gray'
            if self.offset:
                d_tsi += self.offset * (num_file - 1 - i)

            # plot
            if not np.all(np.isnan(d_tsi)):
                self.ax_pts = self.ts_plot_func(self.ax_pts, d_tsi, self, ppar)

        # axis format
        self.ax_pts = _adjust_ts_axis(self.ax_pts, self)
        title_ts = _get_ts_title(yx[0], yx[1], self.coord)
        if self.mask[y, x] == 0:
            title_ts += ' (masked out)'
        if self.disp_title:
            self.ax_pts.set_title(title_ts, fontsize=self.font_size)
        if self.tick_right:
            self.ax_pts.yaxis.tick_right()
            self.ax_pts.yaxis.set_label_position("right")

        # legend
        if len(self.ts_data) > 1:
            self.ax_pts.legend()

        # Print to terminal
        vprint('\n---------------------------------------')
        vprint(title_ts)
        float_formatter = lambda values: [float('{:.2f}'.format(v)) for v in values]
        vprint(float_formatter(d_ts[0]))

        if not np.all(np.isnan(d_ts[0])):
            # stat info
            vprint('displacement range: [{:.2f}, {:.2f}] {}'.format(
                np.nanmin(d_ts[0]), np.nanmax(d_ts[0]), self.disp_unit))

            # estimate (print) slope
            estimate_slope(d_ts[0],
                           self.yearList,
                           ex_flag=self.ex_flag,
                           disp_unit=self.disp_unit)

            # update figure
            self.fig_pts.canvas.draw()
        return d_ts

    def update_plot_timeseries(self, event):
        """Event function to get y/x from button press"""
        if event.inaxes == self.ax_img:
            # get row/col number
            if self.fig_coord == 'geo':
                y, x = self.coord.geo2radar(event.ydata,
                                            event.xdata,
                                            print_msg=False)[0:2]
            else:
                y, x = int(event.ydata + 0.5), int(event.xdata + 0.5)

            # plot time-series displacement
            self.plot_point_timeseries((y, x))

    def on_key_event(self, event):
        """Slide images with left/right key on keyboard"""
        if event.inaxes and event.inaxes.figure == self.fig_img:
            idx = None
            if event.key == 'left':
                idx = max(self.idx - 1, 0)
            elif event.key == 'right':
                idx = min(self.idx + 1, self.num_date - 1)

            if idx is not None and idx != self.idx:
                self._show_epoch(idx)
                self.tslider.set_val(self.yearList[idx])  # update slider
                self.idx = idx
                self.fig_img.canvas.draw()
Exemple #56
0
def profileview(model):
    """Interactive profile viewer driven by a voltage slider.

    Three stacked axes show the latest entries of ``model.proton_history``,
    ``model.oxygen_history`` and ``model.current_history``. They are plotted
    against the distance from the membrane in um, over cells that are
    neither gdl nor membrane.

    Moving the ``V`` slider calls ``model.resolve(V, 263, flood=False)``. It
    then adds the point ``(I, V)`` to the polarization curve, where ``I`` is
    the masked current summed and divided by ``model.face_area`` and 1000,
    and refreshes the three profiles.

    The slider is first swept over ten voltages in [0.05, 1.0] to seed the
    curve and is left at 0.65. Clicking inside the polarization-curve axes
    sets the slider to the clicked y value.
    """
    plt.rc('font', size=8)

    fig, profile_axes = plt.subplots(3, sharex=True)
    titles = ('protonic potential (V)',
              'oxygen molar fraction',
              'current density (A/m2)')
    for axis, title in zip(profile_axes, titles):
        axis.set_title(title)
    profile_axes[2].set_xlabel('distance from membrane (um)')
    fig.subplots_adjust(right=0.45)

    # cells outside gdl and membrane (boundary values excluded)
    interior = ~(model.gdl | model.membrane)
    dist_um = model.distance_from_membrane[interior] * 1E6

    # the polarization curve persists across slider moves
    polar_ax = fig.add_axes([.5, .2, .4, .7])
    polar_ax.set_title('polarization curve')
    polar_ax.yaxis.tick_right()
    polar_ax.yaxis.set_label_position("right")
    polar_ax.set_xlabel('geometric current density (A/cm**2)')
    polar_ax.set_ylabel('voltage at gdl (V)')
    (polcurve,) = polar_ax.plot([], [], 'ko-')

    def on_voltage(V):
        model.resolve(V, 263, flood=False)

        local_current = model.current_history[-1][interior]
        total = local_current.sum() / model.face_area / 1000
        insert_point((total, V), polcurve)
        polar_ax.relim()
        polar_ax.autoscale_view()

        histories = (model.proton_history,
                     model.oxygen_history,
                     model.current_history)
        for axis, history in zip(profile_axes, histories):
            update_ax(axis, dist_um, history[-1])
        fig.canvas.draw()

    slider = Slider(label='V',
                    ax=fig.add_axes([.5, .05, .4, .03]),
                    valmin=0,
                    valmax=1.2)
    slider.on_changed(on_voltage)

    # seed an undersampled polarization curve, then settle on 0.65 V
    for seed_voltage in np.linspace(0.05, 1.0, 10):
        slider.set_val(seed_voltage)
    slider.set_val(0.65)

    def on_click(event):
        if polar_ax is event.inaxes:
            slider.set_val(event.ydata)

    fig.canvas.mpl_connect('button_press_event', on_click)
    plt.show()
Exemple #57
0
def view_patches_bar(Yr, A, C, b, f, d1, d2, YrA=None, img=None):
    """view spatial and temporal components interactively

     Parameters:
     -----------
     Yr:    np.ndarray
            movie in format pixels (d) x frames (T)

     A:     sparse matrix
                matrix of spatial components (d x K)

     C:     np.ndarray
                matrix of temporal components (K x T)

     b:     np.ndarray
                spatial background, one column per background component
                (d x nb); indexed as b[:, k]

     f:     np.ndarray
                temporal background, one row per background component
                (nb x T); nb is taken from f.shape[0]

     d1,d2: int
                frame dimensions

     YrA:   np.ndarray
                 ROI filtered residual as it is given from update_temporal_components
                 If not given, then it is computed (K x T)

     img:   np.ndarray
                background image for contour plotting. Default is the image of all spatial components (d1 x d2)

     The slider (or the left/right arrow keys) selects an index in
     0 .. K + nb - 1. Indices below K show a spatial/temporal component
     pair; the remaining indices show the background components.
    """

    pl.ion()
    if 'csc_matrix' not in str(type(A)):
        A = csc_matrix(A)
    if 'array' not in str(type(b)):
        b = b.toarray()

    nr, T = C.shape
    nb = f.shape[0]
    nA2 = np.sqrt(np.array(A.power(2).sum(axis=0))).squeeze()

    if YrA is None:
        Y_r = spdiags(old_div(1, nA2), 0, nr, nr) * (A.T.dot(Yr) -
                                                     (A.T.dot(b)).dot(f) - (A.T.dot(A)).dot(C)) + C
    else:
        Y_r = YrA + C

    if img is None:
        img = np.reshape(np.array(A.mean(axis=1)), (d1, d2), order='F')

    fig = pl.figure(figsize=(10, 10))

    axcomp = pl.axes([0.05, 0.05, 0.9, 0.03])

    ax1 = pl.axes([0.05, 0.55, 0.4, 0.4])
    ax3 = pl.axes([0.55, 0.55, 0.4, 0.4])
    ax2 = pl.axes([0.05, 0.1, 0.9, 0.4])

    last_comp = nr + nb - 1  # highest valid slider position
    s_comp = Slider(axcomp, 'Component', 0, last_comp, valinit=0)
    vmax = np.percentile(img, 95)

    def update(val):
        """Redraw the three panels for the component selected on the slider."""
        # builtin int: the np.int alias was removed in NumPy 1.24
        i = int(np.round(s_comp.val))
        print(('Component:' + str(i)))

        if i < nr:

            ax1.cla()
            imgtmp = np.reshape(A[:, i].toarray(), (d1, d2), order='F')
            ax1.imshow(imgtmp, interpolation='None', cmap=pl.cm.gray, vmax=np.max(imgtmp)*0.5)
            ax1.set_title('Spatial component ' + str(i + 1))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), Y_r[i], 'c', linewidth=3)
            ax2.plot(np.arange(T), C[i], 'r', linewidth=2)
            ax2.set_title('Temporal component ' + str(i + 1))
            ax2.legend(labels=['Filtered raw data', 'Inferred trace'])

            ax3.cla()
            ax3.imshow(img, interpolation='None', cmap=pl.cm.gray, vmax=vmax)
            # overlay the component footprint, transparent where it is zero
            imgtmp2 = imgtmp.copy()
            imgtmp2[imgtmp2 == 0] = np.nan
            ax3.imshow(imgtmp2, interpolation='None',
                       alpha=0.5, cmap=pl.cm.hot)
            ax3.axis('off')
        else:
            ax1.cla()
            bkgrnd = np.reshape(b[:, i - nr], (d1, d2), order='F')
            ax1.imshow(bkgrnd, interpolation='None')
            ax1.set_title('Spatial background ' + str(i + 1 - nr))
            ax1.axis('off')

            ax2.cla()
            ax2.plot(np.arange(T), np.squeeze(np.array(f[i - nr, :])))
            ax2.set_title('Temporal background ' + str(i + 1 - nr))

    def arrow_key_image_control(event):
        """Step the component slider with the left/right arrow keys."""
        if event.key == 'left':
            new_val = np.round(s_comp.val - 1)
            if new_val < 0:
                new_val = 0
            s_comp.set_val(new_val)

        elif event.key == 'right':
            new_val = np.round(s_comp.val + 1)
            # Clamp to the slider's own maximum: clamping to nr + nb made
            # update() index one past the last background component.
            if new_val > last_comp:
                new_val = last_comp
            s_comp.set_val(new_val)

    s_comp.on_changed(update)
    s_comp.set_val(0)
    fig.canvas.mpl_connect('key_release_event', arrow_key_image_control)
    pl.show()
Exemple #58
0
class Hist4D(object):
    '''
    Interactive 3D histogram drawn as one coloured cube per non-empty cell.

    Call ``draw`` with the (hist, edges) output of ``np.histogramdd`` for
    three dimensions. Cube colours come from ``self.colormap`` evaluated at
    the cell value divided by the largest cell value. Unless ``save_only``
    is True, two sliders ('low' and 'high') are added; only cells whose
    value lies in (low, high] stay visible.
    '''

    def __init__(self, save_only=False):
        '''
        save_only : when True, ``draw`` does not create the sliders.
        '''
        self.fig = None
        # value -> tuple of the artists drawn for every cell holding that
        # value; filled by draw_cubes.
        self.cubes_info = None
        # 'low' / 'high' Slider widgets, created by set_sliders.
        self.slow = None
        self.shigh = None
        # Colormap for the cubes; set by create_*_colormap.
        self.colormap = None
        self.save_only = save_only

    def draw_cubes(self, _axes, vals, edges):
        '''
        Draw one cube per non-zero histogram cell and index the artists.

        _axes : 3D axes the cubes are drawn on
        vals  : LxMxN array of cell values (histogramdd result)
        edges : three edge arrays of lengths L+1, M+1, N+1 (histogramdd)

        Each cube is coloured with ``self.colormap`` at value / max(vals).
        Afterwards ``self.cubes_info`` maps every distinct non-zero value
        to a tuple of all artists drawn for cells holding that value; it
        is empty when no cell is positive.
        '''
        vals = np.asarray(vals)
        x_edges = np.asarray(edges[0])
        y_edges = np.asarray(edges[1])
        z_edges = np.asarray(edges[2])
        # vals[i, j, k] is the cell between edges[0][i:i+2],
        # edges[1][j:j+2], edges[2][k:k+2]. np.nonzero yields those
        # indices directly (in C order), so no meshgrid axis juggling is
        # needed and each cube is drawn exactly once.
        ix, iy, iz = np.nonzero(vals > 0)
        nonzero_vals = vals[ix, iy, iz]
        self.cubes_info = dict()
        if nonzero_vals.size == 0:
            # Nothing to draw and no maximum to normalise by.
            return
        scale = float(np.max(vals))
        for i, j, k, value in zip(ix, iy, iz, nonzero_vals):
            handles = self.draw_cube(_axes,
                                     x_edges[i], x_edges[i + 1],
                                     y_edges[j], y_edges[j + 1],
                                     z_edges[k], z_edges[k + 1],
                                     value / scale)
            self.cubes_info[value] = (self.cubes_info.get(value, ()) +
                                      tuple(handles))

    def set_sliders(self, splot1, splot2):
        '''
        Create the 'low' / 'high' range sliders on the given axes.

        splot1, splot2 : axes hosting the 'low' and 'high' sliders
        Both sliders span [0, max cell value] and call ``update`` when
        changed. They start fully open (low=0, high=max).
        '''
        if self.cubes_info:
            maxlim = max(self.cubes_info.keys())
        else:
            # Empty histogram: degenerate but valid range instead of
            # max() raising on an empty sequence.
            maxlim = 0.0
        self.slow = Slider(splot1, 'low', 0.0, maxlim, valfmt='%0.0f')
        self.shigh = Slider(splot2, 'high', 0.0, maxlim, valfmt='%0.0f')

        self.slow.on_changed(self.update)
        self.shigh.on_changed(self.update)
        self.slow.set_val(0)
        self.shigh.set_val(maxlim)

    def update(self, val):
        '''
        Slider callback: show cells with low < value <= high, hide the rest.

        ``val`` (the new slider value) is ignored; both slider values are
        read directly. A visible cube gets back the alpha its colour had
        when drawn (colormap alpha at value / max value). The raw cell
        value is not used as alpha because counts above 1 are not valid
        alphas.
        '''
        low = self.slow.val
        high = self.shigh.val
        top = float(max(self.cubes_info)) if self.cubes_info else 1.0
        for value, handles in self.cubes_info.items():
            if low < value <= high:
                alpha = self.colormap(float(value) / top)[3]
            else:
                alpha = 0
            for handle in handles:
                handle.set_alpha(alpha)
        self.fig.canvas.draw_idle()

    def draw_cube(self, _axes, x1_coord, x2_coord,
                  y1_coord, y2_coord,
                  z1_coord, z2_coord,
                  color_ind):
        '''
        Draw the box [x1, x2] x [y1, y2] x [z1, z2] as six surfaces.

        All faces use ``self.colormap(color_ind)``. Returns the list of the
        six artists returned by ``_axes.plot_surface``.
        '''
        _x_coord, _y_coord, _z_coord = np.meshgrid([x1_coord, x2_coord],
                                                   [y1_coord, y2_coord],
                                                   [z1_coord, z2_coord])
        # 3x8 corner array; columns 0-3 are the y=y1 face, 4-7 the y=y2 face.
        tmp1 = np.concatenate((_x_coord.ravel()[None, :], _y_coord.ravel()[
            None, :], _z_coord.ravel()[None, :]), axis=0)
        # Swapping corner columns regroups them into the x=x2 / x=x1 faces...
        tmp2 = tmp1.copy()
        tmp2[:, [0, 1]], tmp2[:, [6, 7]] = tmp2[
            :, [6, 7]].copy(), tmp2[:, [0, 1]].copy()
        # ...and swapping again gives the z=z2 / z=z1 faces.
        tmp3 = tmp2.copy()
        tmp3[:, [0, 2]], tmp3[:, [5, 7]] = tmp3[
            :, [5, 7]].copy(), tmp3[:, [0, 2]].copy()
        points = np.concatenate((tmp1, tmp2, tmp3), axis=1)
        points = points.T.reshape(6, 4, 3)  # 6 faces x 4 corners x (x, y, z)
        face_color = self.colormap(float(color_ind))
        surf = []
        for count in range(6):
            surf.append(_axes.plot_surface(points[count, :, 0].reshape(2, 2),
                                           points[count, :, 1].reshape(2, 2),
                                           points[count, :, 2].reshape(2, 2),
                                           color=face_color,
                                           linewidth=0,
                                           antialiased=True,
                                           shade=False))
        return surf

    def array2cmap(self, X):
        '''
        Build a stepped LinearSegmentedColormap from an Nx4 RGBA array.

        Row i of ``X`` is the constant colour on [i/N, (i+1)/N]. An Nx3 RGB
        array is also accepted; its alpha is taken as 1.
        '''
        X = np.asarray(X, dtype=float)
        N = X.shape[0]
        if X.shape[1] == 3:
            X = np.concatenate((X, np.ones((N, 1))), axis=1)
        # Interior breakpoints appear twice so every row is a flat step:
        # [0, 1/N, 1/N, 2/N, 2/N, ..., 1].
        r = np.repeat(np.linspace(0., 1., N + 1), 2)[1:-1]
        cdict = {}
        for name, column in (('red', 0), ('green', 1),
                             ('blue', 2), ('alpha', 3)):
            channel = np.repeat(X[:, column], 2)
            cdict[name] = tuple((r[i], channel[i], channel[i])
                                for i in range(2 * N))
        return colors.LinearSegmentedColormap('my_colormap', cdict)

    def draw_colorbar(self, _axes, unique_vals=None, cax=None):
        '''
        Add a colorbar for ``self.colormap`` next to ``_axes``.

        unique_vals : values used for the 5 tick labels; defaults to 1000
                      evenly spaced values in [0, 1]
        cax         : optional axes to draw the colorbar into
        A throwaway scatter supplies the mappable. It is made transparent
        afterwards, and the axis limits it may have changed are restored.
        '''
        if unique_vals is None:
            unique_vals = np.linspace(0, 1, 1000)
        xmin, xmax = _axes.get_xlim()
        ymin, ymax = _axes.get_ylim()
        zmin, zmax = _axes.get_zlim()
        invis = _axes.scatter(unique_vals,
                              unique_vals,
                              c=np.arange(len(unique_vals)),
                              cmap=self.colormap)
        _axes.set_xlim([xmin, xmax])
        _axes.set_ylim([ymin, ymax])
        _axes.set_zlim([zmin, zmax])
        cbar = self.fig.colorbar(invis, ax=_axes, cax=cax, drawedges=False)
        # Colour indices run 0..len-1; spreading ticks up to len would put
        # the last tick outside the bar and misalign the labels.
        cbar.set_ticks(np.linspace(0, np.size(unique_vals) - 1, 5))
        cbar.set_ticklabels(np.around(np.linspace(0, np.max(unique_vals), 5), 2))

        invis.set_alpha(0)

    def create_opacity_colormap(self, principal_rgb_color, scale_size=256):
        '''
        Create an opacity colormap based on one principal RGB color.

        principal_rgb_color : RGB triple with components in [0, 1]
        scale_size          : number of steps; alpha ramps linearly 0 -> 1
        Stores the result in ``self.colormap``. Raises Exception when a
        component lies outside [0, 1].
        '''
        principal_rgb_color = np.asarray(principal_rgb_color, dtype=float)
        if np.any(principal_rgb_color > 1) or np.any(principal_rgb_color < 0):
            raise Exception('principal_rgb_color values should  be in range [0,1]')
        opac_colormap = np.concatenate((np.tile(principal_rgb_color[None, :],
                                                (scale_size, 1)),
                                        np.linspace(0, 1, scale_size)[:, None]),
                                       axis=1)
        self.colormap = self.array2cmap(opac_colormap)

    def create_brightness_colormap(self, principal_rgb_color, scale_size=256):
        '''
        Create a brightness colormap based on one principal RGB color.

        Keeps the color's hue and saturation and ramps HSV value 0 -> 1
        over ``scale_size`` steps (fully opaque). Stores the result in
        ``self.colormap``. Raises Exception when a component lies outside
        [0, 1].
        '''
        principal_rgb_color = np.asarray(principal_rgb_color, dtype=float)
        if np.any(principal_rgb_color > 1) or np.any(principal_rgb_color < 0):
            raise Exception('principal_rgb_color values should  be in range [0,1]')
        hsv_color = colors.rgb_to_hsv(principal_rgb_color)
        hsv_colormap = np.concatenate((np.tile(hsv_color[:-1][None, :], (scale_size, 1)),
                                       np.linspace(0, 1, scale_size)[:, None]),
                                      axis=1)
        # hsv_to_rgb yields Nx3 RGB; array2cmap supplies alpha = 1.
        self.colormap = self.array2cmap(colors.hsv_to_rgb(hsv_colormap))

    def draw(self, hist, edges,
             fig=None, gs=None, subplot=None,
             color=(1, 0, 0), all_axes=None):
        '''
        Draw the histogram cubes, colorbar and (optionally) sliders.

        hist, edges : histogramdd output (3 dimensions)
        fig         : figure handle; a new figure is created when None
        gs          : contiguous slice (or whole) of a gridspec to host the
                      plot; defaults to a new 50x50 GridSpec
        subplot     : unused, kept for backward compatibility
        color       : principal RGB color of the opacity colormap
        all_axes    : optional (3d axes, colorbar axes, low-slider axes,
                      high-slider axes) to reuse; they are cleared first
        Returns (3d axes, low-slider axes, high-slider axes).
        '''
        if fig is not None:
            self.fig = fig
        else:
            self.fig = plt.figure()

        if gs is None:
            gs = gridspec.GridSpec(50, 50)
        if all_axes is None:
            _axes = self.fig.add_subplot(gs[:-5, :45], projection='3d')
            cax = self.fig.add_subplot(gs[:-5, 45:])
            ax1 = self.fig.add_subplot(gs[-4:-2, :])
            ax2 = self.fig.add_subplot(gs[-2:, :])
        else:
            _axes, cax, ax1, ax2 = all_axes
            _axes.clear()
            cax.clear()
            ax1.clear()
            ax2.clear()
        self.create_opacity_colormap(color)
        self.draw_cubes(_axes, hist, edges)
        self.draw_colorbar(_axes, cax=cax)
        _axes.set_xlim((edges[0].min(), edges[0].max()))
        _axes.set_ylim((edges[1].min(), edges[1].max()))
        _axes.set_zlim((edges[2].min(), edges[2].max()))
        self.fig.patch.set_facecolor('white')
        # Axes3D.xaxis/yaxis/zaxis: the w_?axis aliases were deprecated in
        # Matplotlib 3.1 and removed in 3.8.
        _axes.xaxis.set_pane_color((0.8, 0.8, 0.8, 1.0))
        _axes.yaxis.set_pane_color((0.8, 0.8, 0.8, 1.0))
        _axes.zaxis.set_pane_color((0.8, 0.8, 0.8, 1.0))
        if not self.save_only:
            self.set_sliders(ax1, ax2)
        return _axes, ax1, ax2