Exemple #1
0
 def record(self, event):
     if event.inaxes is None:
         return
     self.slit = Slit(self.axes, self.pixel_scale)
     self.slits.append(self.slit)
     self.cursor = widgets.Cursor(self.axes, useblit=False, color='red', linewidth=1)
     self.cid = self.fig.canvas.mpl_connect('button_press_event', self.get_click)
    def load_slit(self, event):
        files = glob.glob(self.savedir + '*.npz')

        if len(files) <= 0:
            print('There seems to be no save files in this directory.')
            return

        for i in range(len(files)):
            name = files[i]
            self.slit = Slit(self.axes, self.pixel_scale)
            self.slits.append(self.slit)
            data = np.load(name).items()
            self.slit.data = data[0][1]
            self.slit.curve_points = data[1][1]
            self.slit.length = data[2][0]
            self.slit.mpl_curve.append(
                self.axes.plot(self.slit.curve_points[:, 0],
                               self.slit.curve_points[:, 1]))
            self.axes.figure.canvas.draw()
            slit = np.zeros([len(self.range), self.nt, self.slit.res])
            for i, idx in enumerate(self.range):
                slit[i, :, :] = self.slit.get_slit_data(
                    self.data[:, idx, :, :], self.image_extent)
            slit = self.slit.get_slit_data(
                self.data[:, self.sliders[1]._slider.cval, :, :],
                self.image_extent)
            self.slit.length *= self.pixel_scale
            self.slit.data = slit
            self.plot_slits(slit)
    def load_slit(self, event):
        files_npz = glob.glob(self.savedir + '*.npz')
        files_npy = glob.glob(self.savedir + '*.npy')

        if len(files_npz) > 0 and len(files_npy) == 0:
            files = files_npz
            flag = 'npz'
        elif len(files_npy) > 0 and len(files_npz) == 0:
            files = files_npy
            flag = 'npy'
        else:
            print('Needs work and needs Ben Fogle')
            return

        for i in range(len(files)):
            name = files[i]
            self.slit = Slit(self.axes)
            self.slits.append(self.slit)
            if flag == 'npz':
                data = np.load(name).items()
                self.slit.data = data[0][1]
                self.slit.curve_points = data[1][1]
#                self.slit.distance = data[2][0]
            elif flag == 'npy':
                self.slit.curve_points[:,0], self.slit.curve_points[:,1] = zip(*np.load(name))
            self.slit.mpl_curve.append(self.axes.plot(self.slit.curve_points[:,0], self.slit.curve_points[:,1]))
            self.axes.figure.canvas.draw()
            slit = np.zeros([self.nlambda,self.nt,self.slit.res])
            for i in range(self.nlambda):
                slit[i,:,:] = self.slit.get_slit_data(self.data[:,i,:,:],self.image_extent)
            self.slit.distance *= self.pixel_scale
            self.slit.data = slit
            self.plot_slits(slit)
Exemple #4
0
    def load_slit(self, event):
        files = glob.glob(self.savedir + '*.npz')

        if len(files) <= 0:
            print('There seems to be no save files in this directory.')
            return

        for i in range(len(files)):
            name = files[i]
            self.slit = Slit(self.axes, self.pixel_scale)
            self.slits.append(self.slit)
            data = np.load(name).items()
            self.slit.data = data[0][1]
            self.slit.curve_points = data[1][1]
            self.slit.length = data[2][0]
            self.slit.mpl_curve.append(
                self.axes.plot(self.slit.curve_points[:, 0], self.slit.curve_points[:, 1]))
            self.axes.figure.canvas.draw()
            slit = np.zeros([len(self.range), self.nt, self.slit.res])
            for i, idx in enumerate(self.range):
                slit[i, :, :] = self.slit.get_slit_data(self.data[:, idx, :, :], self.image_extent)
            slit = self.slit.get_slit_data(self.data[:, self.sliders[1]._slider.cval, :, :],
                                           self.image_extent)
            self.slit.length *= self.pixel_scale
            self.slit.data = slit
            self.plot_slits(slit)
Exemple #5
0
class PlotInteractor(ImageAnimator):
    """
    A PlotInteractor.
    t,lambda,x,y

    Parameters
    ----------
    data: np.ndarray
        A 4D array

    pixel_scale: float
        Pixel scale for spatial axes

    save_dir: string
        dir to save slit files to

    axis_range: list or ndarray
        [min, max] pairs for each image axis and [min, max] pairs or arrays
        of values for each slider axis.
        Otherwise it just takes the shape and returns a non-physical index.
    """
    def __init__(self, data, pixel_scale, cadence, interop, savedir, **kwargs):
        all_axes = list(range(data.ndim))
        image_axes = [all_axes[i] for i in kwargs.get('image_axes', [-2, -1])]
        self.slider_axes = list(range(data.ndim))
        for x in image_axes:
            self.slider_axes.remove(x)

        if 'cmap' not in kwargs:
            kwargs['cmap'] = plt.get_cmap('gray')

        axis_range = [None, None,
                      [0, pixel_scale * data[0, 0, :, :].shape[0]],
                      [0, pixel_scale * data[0, 0, :, :].shape[1]]]
        axis_range = kwargs.pop('axis_range', axis_range)

        axis_range = self._sanitize_axis_range(axis_range, data)

        self.image_extent = list(itertools.chain.from_iterable([axis_range[i] for i in image_axes]))
        self.pixel_scale = pixel_scale
        self.cadence = cadence
        self.slits = []
        self.savedir = savedir
        self.nt = data.shape[0]
        self.nlambda = data.shape[1]
        self.interop = interop
        self.range = range(0, data.shape[1])
        
        button_labels, button_func = self.create_buttons()

        slider_functions = [self._updateimage]*len(self.slider_axes) + [self.update_range]*2 + [self.update_im_clim]*2
        slider_ranges = [axis_range[i] for i in self.slider_axes] + [range(0, self.nlambda)]*2 + [np.arange(0, 99.9)]*2

        ImageAnimator.__init__(self, data, axis_range=axis_range,
                               button_labels=button_labels,
                               button_func=button_func,
                               slider_functions=slider_functions,
                               slider_ranges=slider_ranges,
                               **kwargs)

        # Sets up the slit sliders
        self.sliders[2]._slider.set_val(self.nlambda)
        self.sliders[3]._slider.slidermax = self.sliders[2]._slider
        self.sliders[2]._slider.slidermin = self.sliders[3]._slider
        self.slider_buttons[3].set_visible(False)
        self.slider_buttons[2].set_visible(False)
        self.label_slider(3, "Start")
        self.label_slider(2, "End")

        # Sets up the intensity scaling sliders
        self.sliders[-2]._slider.set_val(100)
        self.sliders[-1]._slider.slidermax = self.sliders[-2]._slider
        self.sliders[-2]._slider.slidermin = self.sliders[-1]._slider
        self.slider_buttons[-1].set_visible(False)
        self.slider_buttons[-2].set_visible(False)
        self.axes.autoscale(False)
        self.label_slider(-1, "Min")
        self.label_slider(-2, "Max")

    def create_buttons(self):
        button_labels = ['Slit', 'Delete', 'Save', 'Load']
        button_func = [self.record, self.delete, self.save_slit, self.load_slit]

        return button_labels, button_func

    def update_im_clim(self, val, im, slider):
        if np.mean(self.data[self.frame_slice]) < 0:
            self.im.set_clim(np.min(self.data[self.frame_slice]) * (self.sliders[-1]._slider.val / 100),
                             np.max(self.data[self.frame_slice]) * (self.sliders[-2]._slider.val / 100))
        else:
            self.im.set_clim(np.max(self.data[self.frame_slice]) * (self.sliders[-1]._slider.val / 100),
                             np.max(self.data[self.frame_slice]) * (self.sliders[-2]._slider.val / 100))

    def update_range(self, val, im, slider):
        self.range = np.arange(int(self.sliders[3]._slider.val),int(self.sliders[2]._slider.val))
        if len(self.range) == 0:
            self.range = np.arange(int(self.sliders[3]._slider.val)-1,int(self.sliders[2]._slider.val)+1,1)

# =============================================================================
# Button Functions
# =============================================================================

    def delete(self, event):
        if not hasattr(self.slit, 'mpl_points'):
            print('You have not yet generated a curve to delete.')
        else:
            if len(self.slit.mpl_points) > 0 and len(self.slit.mpl_curve) > 0:
                self.slit.remove_all(self.slits)
                self.cid = None
                self.slits = []
        if hasattr(self, 'cursor'):
            self.fig.canvas.mpl_disconnect(self.cid)
            del self.cursor

    def record(self, event):
        if event.inaxes is None:
            return
        self.slit = Slit(self.axes, self.pixel_scale)
        self.slits.append(self.slit)
        self.cursor = widgets.Cursor(self.axes, useblit=False, color='red', linewidth=1)
        self.cid = self.fig.canvas.mpl_connect('button_press_event', self.get_click)

    def save_slit(self, event, filename=False):
        if not hasattr(self.slit, 'mpl_points'):
            print('There is no slit to save.')
        else:
            names = ['curve_points', 'slit_data', 'length']
            if not filename:
                filename = str(datetime.datetime.now())
                np.savez(self.savedir + filename, names, self.slit.curve_points, self.slit.data, self.slit.length)

    def load_slit(self, event):
        files = glob.glob(self.savedir + '*.npz')

        if len(files) <= 0:
            print('There seems to be no save files in this directory.')
            return

        for i in range(len(files)):
            name = files[i]
            self.slit = Slit(self.axes, self.pixel_scale)
            self.slits.append(self.slit)
            data = np.load(name).items()
            self.slit.data = data[0][1]
            self.slit.curve_points = data[1][1]
            self.slit.length = data[2][0]
            self.slit.mpl_curve.append(self.axes.plot(self.slit.curve_points[:, 0], self.slit.curve_points[:, 1]))
            self.axes.figure.canvas.draw()
            slit = np.zeros([len(self.range), self.nt, self.slit.res])
            for i, idx in enumerate(self.range):
                slit[i, :, :] = self.slit.get_slit_data(self.data[:, idx, :, :], self.image_extent)                
            slit = self.slit.get_slit_data(self.data[:, self.sliders[1]._slider.cval, :, :], self.image_extent)
            self.slit.length *= self.pixel_scale
            self.slit.data = slit
            self.plot_slits(slit)

# =============================================================================
# Figure Callbacks
# =============================================================================

    def get_click(self, event):
        if event.inaxes is not None:
            if event.inaxes is self.axes and event.button == 1:
                self.slit.add_point(event.xdata, event.ydata)
            elif event.inaxes is self.axes and event.button == 3:
                self.slit.remove_point()
            elif event.inaxes is self.axes and event.button == 2:
                self.slit.create_curve(self.interop)
                slit = np.zeros([len(self.range), self.nt, self.slit.res])
                for i, idx in enumerate(self.range):
                    slit[i, :, :] = self.slit.get_slit_data(self.data[:, idx, :, :], self.image_extent)
                self.slit.length *= self.pixel_scale
                self.slit.data = slit
                self.plot_slits(slit)
                self.fig.canvas.mpl_disconnect(self.cid)
                self.cid = None
            else:
                print('Click a real mouse button')

    def plot_slits(self, slit):
        extent = [0, self.nt*self.cadence, 0, self.slit.length]
        fig, axes = plt.subplots(nrows=slit.shape[0], ncols=1,
                                 sharex=True, sharey=True, figsize=(10, 18))
        if slit.shape[0] == 1:
            axes = [axes]

        for i in range(slit.shape[0]):
            loc_mean = slit[i, :, :].T/np.max(np.abs(slit[i, :, :].T))
            axes[i].imshow(loc_mean[:, :], origin='lower', interpolation='nearest',
                           cmap=plt.get_cmap('Greys_r'), extent=extent, aspect='auto')
            axes[i].set_xlim(0, extent[1])
            axes[i].set_ylim(0, extent[3])
        plt.xlabel('Time (seconds)')
        plt.ylabel('Length along slit (arcsecs)')
        fig.tight_layout()
        fig.subplots_adjust(hspace=0, wspace=0)
        fig.show()
class PlotInteractor(ImageAnimator):
    """
    A PlotInteractor.
    t,lambda,x,y

    Parameters
    ----------
    data: np.ndarray
        A 4D array

    pixel_scale: float
        Pixel scale for spatial axes

    save_dir: string
        dir to save slit files to

    axis_range: list or ndarray
        [min, max] pairs for each image axis and [min, max] pairs or arrays
        of values for each slider axis.
    """
    def __init__(self, data, pixel_scale, savedir, **kwargs):
        all_axes = list(range(data.ndim))
        image_axes = [all_axes[i] for i in kwargs.get('image_axes', [-2,-1])]
        self.slider_axes = list(range(data.ndim))
        for x in image_axes:
            self.slider_axes.remove(x)

        axis_range = [None,None,
                      [0, pixel_scale * data[0,0,:,:].shape[0]],
                      [0, pixel_scale * data[0,0,:,:].shape[1]]]
        axis_range = kwargs.pop('axis_range', axis_range)

        axis_range = self._sanitize_axis_range(axis_range, data)

        self.image_extent = list(itertools.chain.from_iterable([axis_range[i] for i in image_axes]))
        self.pixel_scale = pixel_scale
        self.r_diff = []
        self.slits = []
        self.savedir = savedir
        self.nlambda = data.shape[1]
        self.nt = data.shape[0]

        button_labels, button_func = self.create_buttons()

        slider_functions = [self._updateimage]*len(self.slider_axes) + [self.update_im_clim]*2
        slider_ranges = [axis_range[i] for i in self.slider_axes] + [np.arange(0,99.9)]*2
        
        ImageAnimator.__init__(self, data, axis_range=axis_range,
                                  button_labels=button_labels,
                                  button_func=button_func,
                                  slider_functions=slider_functions,
                                  slider_ranges=slider_ranges,
                                  **kwargs)

        self.sliders[-2]._slider.set_val(100)
        self.sliders[-1]._slider.slidermax = self.sliders[-2]._slider
        self.sliders[-2]._slider.slidermin = self.sliders[-1]._slider
        self.slider_buttons[-1].set_visible(False)
        self.slider_buttons[-2].set_visible(False)
        self.axes.autoscale(False)
        self.label_slider(-1, "Min")
        self.label_slider(-2, "Max")

    def create_buttons(self):
        button_labels = ['Slit', 'Delete', 'Save', 'Load']
        button_func = [self.record, self.delete, self.save_slit, self.load_slit]

        return button_labels, button_func

    def update_im_clim(self, val, im, slider):
        if np.mean(self.data[self.frame_slice]) < 0:
            self.im.set_clim(np.min(self.data[self.frame_slice]) * (self.sliders[-1]._slider.val / 100),
                         np.max(self.data[self.frame_slice]) * (self.sliders[-2]._slider.val / 100))
        else:
            self.im.set_clim(np.max(self.data[self.frame_slice]) * (self.sliders[-1]._slider.val / 100),
                         np.max(self.data[self.frame_slice]) * (self.sliders[-2]._slider.val / 100))
#==============================================================================
# Button Functions
#==============================================================================

    def delete(self, event):
        if not hasattr(self.slit, 'mpl_points'):
            print('You have not yet generated a curve to delete.')
        else:
            if len(self.slit.mpl_points) > 0 and len(self.slit.mpl_curve) > 0:
                self.slit.remove_all(self.slits)
                self.cid = None
                self.slits = []
        if hasattr(self, 'cursor'):
            self.fig.canvas.mpl_disconnect(self.cid)
            del self.cursor

    def record(self, event):
        if event.inaxes is None: return
        self.slit = Slit(self.axes)
        self.slits.append(self.slit)
        self.cursor = widgets.Cursor(self.axes, useblit=False, color='red', linewidth=1)
        self.cid = self.fig.canvas.mpl_connect('button_press_event', self.get_click)

    def save_slit(self, event, filename=False):
        if not hasattr(self.slit, 'mpl_points'):
            print('SAVE BEN FOGLE, SAVE THE SLIT.')
        else:
            names = ['curve_points', 'slit_data', 'distance']
            if not filename:
                filename = str(datetime.datetime.now())
            if self.r_diff:
                np.savez(self.savedir + filename, names + ['run_diff'],
                         self.slit.curve_points, self.slit.data, self.slit.distance, self.slit.data_run)
            else:
                np.savez(self.savedir + filename, names,
                         self.slit.curve_points, self.slit.data, self.slit.distance)

    def load_slit(self, event):
        files_npz = glob.glob(self.savedir + '*.npz')
        files_npy = glob.glob(self.savedir + '*.npy')

        if len(files_npz) > 0 and len(files_npy) == 0:
            files = files_npz
            flag = 'npz'
        elif len(files_npy) > 0 and len(files_npz) == 0:
            files = files_npy
            flag = 'npy'
        else:
            print('Needs work and needs Ben Fogle')
            return

        for i in range(len(files)):
            name = files[i]
            self.slit = Slit(self.axes)
            self.slits.append(self.slit)
            if flag == 'npz':
                data = np.load(name).items()
                self.slit.data = data[0][1]
                self.slit.curve_points = data[1][1]
#                self.slit.distance = data[2][0]
            elif flag == 'npy':
                self.slit.curve_points[:,0], self.slit.curve_points[:,1] = zip(*np.load(name))
            self.slit.mpl_curve.append(self.axes.plot(self.slit.curve_points[:,0], self.slit.curve_points[:,1]))
            self.axes.figure.canvas.draw()
            slit = np.zeros([self.nlambda,self.nt,self.slit.res])
            for i in range(self.nlambda):
                slit[i,:,:] = self.slit.get_slit_data(self.data[:,i,:,:],self.image_extent)
            self.slit.distance *= self.pixel_scale
            self.slit.data = slit
            self.plot_slits(slit)
            
#==============================================================================
# Figure Callbacks
#==============================================================================

    def get_click(self, event):
        if not event.inaxes is None:
            if event.inaxes is self.axes and event.button == 1:
                self.slit.add_point(event.xdata,event.ydata)
            elif event.inaxes is self.axes and event.button == 3:
                self.slit.remove_point()
            elif event.inaxes is self.axes and event.button == 2:
                self.slit.create_curve()
                slit = np.zeros([self.nlambda,self.nt,self.slit.res])
                for i in range(self.nlambda):
                    slit[i,:,:] = self.slit.get_slit_data(self.data[:,i,:,:],self.image_extent)
#                profiler.stop()
#                print(profiler.output_text(unicode=True, color=True))
                self.slit.distance *= self.pixel_scale
                self.slit.data = slit
                self.plot_slits(slit)
                self.fig.canvas.mpl_disconnect(self.cid)
                self.cid = None
            else:
                print('Click a real mouse button')

    def plot_slits(self, slit, r_diff=False):
        extent = [0, self.nt, 0 , self.slit.distance]
        self.r_diff = r_diff

        if r_diff:
            fig, axes = plt.subplots(nrows=self.nlambda, ncols=2,
                                     sharex=True, sharey=True, figsize = (10,8))
        else:
            fig, axes = plt.subplots(nrows=self.nlambda, ncols=1,
                                     sharex=True, sharey=True, figsize = (6,9))
        if self.nlambda == 1 and not r_diff:
            axes = [axes]

        for i in range(0, self.nlambda):
            if r_diff:
                rundiff = self.slit.get_run_diff(slit[i,:,:])
                axes[1].imshow(rundiff[:,:].T/np.max(np.abs(rundiff[:,:].T)), origin='lower',
                                interpolation='spline36',
                                 cmap=plt.get_cmap('Greys_r'), extent = extent,
                                    aspect='auto')
                axes[0].imshow(slit[i,:,:].T/np.max(np.abs(slit[i,:,:].T)), origin='lower',
                                    interpolation='spline36',
                                    cmap=plt.get_cmap('Greys_r'), extent = extent,
                                    aspect='auto')
            else:
                loc_mean = slit[i,:,:].T/np.max(np.abs(slit[i,:,:].T))

                axes[i].imshow(loc_mean[:,:], origin='lower',
                                interpolation='spline36',
                                cmap=plt.get_cmap('Greys_r'), extent = extent,
                                    aspect='auto')
                axes[i].set_xlim(0,extent[1])
                axes[i].set_ylim(0,extent[3])
        fig.tight_layout()
        fig.subplots_adjust(hspace=0, wspace=0)
        fig.show()