Ejemplo n.º 1
0
class CubeWidget(ImageWidget):
    """Image widget for browsing a 3D data cube slice by slice.

    The widget shows one slice of the cube in the underlying viewer, a
    wavelength slider plus a ``Play`` control to step through slices, and an
    output area holding the spectrum at the most recently clicked pixel.

    The wavelength of slice ``i`` is taken to be
    ``wave_start + i * dwave`` where ``wave_start`` and ``dwave`` are
    ``crval[2]`` and ``cdelt[2]`` of the WCS; ``crpix`` is not consulted.
    """

    def __init__(self,
                 hdu=None,
                 im=None,
                 wcs=None,
                 show_rainbow=True,
                 *args,
                 **kwargs):
        """Build the widget and display it.

        Parameters
        ----------
        hdu : optional
            Object exposing ``.data`` (the cube) and ``.header`` (used to
            build the WCS). Takes precedence over ``im``/``wcs``.
        im : optional
            Cube array, used together with ``wcs`` when ``hdu`` is None.
            Its first axis is treated as the wavelength axis.
        wcs : optional
            WCS object matching ``im``; required when ``im`` is given.
        show_rainbow : bool
            If True, the area under the spectrum is coloured by wavelength.
        *args, **kwargs
            Forwarded unchanged to ``ImageWidget.__init__``.

        Raises
        ------
        ValueError
            If neither ``hdu`` nor both ``im`` and ``wcs`` are supplied.
        """
        super().__init__(*args, **kwargs)

        if hdu is not None:
            self.im = hdu.data
            self.wcs = WCS(hdu.header)
        elif im is not None and wcs is not None:
            self.im = im
            self.wcs = wcs
        else:
            # Fail here with a clear message instead of crashing further
            # down on a missing attribute.
            raise ValueError("Provide a 3D HDU or image and wcs object")

        self.nddata = NDData(self.im, wcs=self.wcs)
        self.load_nddata(self.nddata, n=0)

        # Wavelength solution: third WCS axis, first array axis.
        self.dwave = self.wcs.wcs.cdelt[2]
        self.wave_start = self.wcs.wcs.crval[2]
        self.nwave = np.shape(self.im)[0]
        # Upper edge of the last wavelength bin (one step past the last
        # slice); used as the slider maximum and as the plot/colour range.
        self.wave_end = self.wave_start + self.nwave * self.dwave
        self.show_rainbow = show_rainbow

        self.cuts = 'stddev'

        # NOTE(review): 4540 is a hard-coded initial wavelength, presumably
        # chosen for one specific instrument - confirm it suits other cubes.
        self.wave_widget = widgets.IntSlider(
            min=self.wave_start,
            max=self.wave_end,
            step=self.dwave,
            value=4540,
            continuous_update=False,
        )

        self.slider = widgets.interactive(self.show_slice,
                                          wave=self.wave_widget)

        self.animate_button = widgets.Button(
            description="Scan Cube",
            disabled=False,
            button_style="success",
            tooltip="Click this to scan in wavelength dimension",
        )

        # State for the line profile plot.
        self._cur_islice = None
        self._cur_ix = None
        self._cur_iy = None
        self.line_out = widgets.Output()
        self.line_plot = None
        self.plot_xlabel = "Wavelength (A)"
        self.plot_ylabel = "Flux Density"  # (10^-17 erg cm-2 s-1 arcsec-2"

        # Defined up front so the rainbow can be built lazily if
        # ``show_rainbow`` is switched on after construction.
        self.clim = None
        self.spectralmap = None
        if self.show_rainbow:
            self.set_rainbow()

        # If plot shows, rerun cell to hide it.
        self.line_plot = plt.gca()

        self.scan = widgets.Play(
            value=self.wave_start,
            min=self.wave_start,
            max=self.wave_end,
            step=self.dwave,
            description="Scan Cube",
            disabled=False,
        )

        # Keep the play control and the slider on the same wavelength.
        widgets.jslink((self.scan, "value"), (self.wave_widget, "value"))

        left_panel = widgets.VBox(
            [widgets.HBox([self.wave_widget, self.scan]), self])

        display(widgets.HBox([left_panel, self.line_out]))

    def load_nddata(self, nddata, n=0):  # update this for wavelength later
        """Load ``nddata`` into a new ``AstroImage`` and show slice ``n``."""
        self.image = AstroImage()
        self.image.load_nddata(nddata, naxispath=[n])
        self._viewer.set_image(self.image)

    def _mouse_click_cb(self, viewer, event, data_x, data_y):
        """Record the clicked pixel, replot its spectrum and mark it.

        The coordinates are rounded to the nearest integer pixel. All
        arguments are then passed on to the parent class callback.
        """
        self._cur_ix = int(round(data_x))
        self._cur_iy = int(round(data_y))
        self.plot_line_profile()

        # Ensure only the active marker is shown.
        self.reset_markers()
        mrk_tab = Table(names=["x", "y"])
        mrk_tab.add_row([self._cur_ix, self._cur_iy])
        self.marker = {"color": "red", "radius": 1, "type": "circle"}
        self.add_markers(mrk_tab)

        super()._mouse_click_cb(viewer, event, data_x, data_y)

    def plot_line_profile(self):
        """Redraw the spectrum at the last clicked pixel into ``line_out``.

        Does nothing when there is no plot axes, no clicked pixel, no loaded
        image, or when the clicked pixel lies outside the cube. A red
        vertical line marks the wavelength of the slice currently shown.
        """
        if self.line_plot is None or self._cur_ix is None or self._cur_iy is None:
            return

        if self.image is None:
            return

        # Negative indices would silently wrap around to the far side of
        # the cube instead of raising IndexError.
        if self._cur_ix < 0 or self._cur_iy < 0:
            return

        with self.line_out:
            mddata = self.image.get_mddata()

            # Validate the click before clearing, so an out-of-range click
            # leaves the previous plot untouched.
            try:
                spectrum = mddata[:, self._cur_iy, self._cur_ix]
            except IndexError:
                return

            self.spectrum = spectrum
            self.wavelengths = (self.wave_start +
                                self.dwave * np.arange(self.nwave))

            self.line_plot.clear()
            self.line_plot.plot(
                self.wavelengths,
                self.spectrum,
                color="black",
                linewidth=1.2,
            )

            if self._cur_islice is not None:
                x = self.wave_start + self.dwave * self._cur_islice
                self.line_plot.axvline(x=x, color="r", linewidth=1)

            self.line_plot.set_xlabel(self.plot_xlabel)
            self.line_plot.set_ylabel(self.plot_ylabel)
            self.line_plot.set_xlim(self.wave_start, self.wave_end)

            if self.show_rainbow:
                if getattr(self, "spectralmap", None) is None:
                    self.set_rainbow()

                # Take the flux range from finite values only; a NaN would
                # otherwise propagate into the image extent.
                values = np.asarray(self.spectrum)
                finite = values[np.isfinite(values)]

                if finite.size:
                    flux_lo, flux_hi = finite.min(), finite.max()
                    y2 = np.linspace(flux_lo, flux_hi, 100)
                    # Every row holds the wavelengths, so the colour of a
                    # column depends only on its wavelength.
                    X, _ = np.meshgrid(self.wavelengths, y2)

                    extent = (self.wave_start, self.wave_end, flux_lo,
                              flux_hi)

                    self.line_plot.imshow(X,
                                          clim=self.clim,
                                          extent=extent,
                                          cmap=self.spectralmap,
                                          aspect='auto')

                    # Paint white above the curve so that only the area
                    # under the spectrum stays coloured.
                    self.line_plot.fill_between(self.wavelengths,
                                                self.spectrum,
                                                flux_hi,
                                                color='w')

            clear_output(wait=True)
            display(self.line_plot.figure)

    def set_rainbow(self):
        """Build the wavelength colormap used to fill under the spectrum.

        Sets ``self.clim`` to ``(wave_start, wave_end)`` and
        ``self.spectralmap`` to a colormap whose stops are coloured by
        ``wavelength_to_rgb``.
        """
        lo, hi = self.wave_start, self.wave_end
        self.clim = (lo, hi)
        norm = plt.Normalize(lo, hi)

        # Roughly one colour stop per 2 wavelength units. linspace makes the
        # stops land exactly on both ends of the range, so the normalised
        # positions run from 0 to 1 whatever the range is.
        n_stops = max(2, int(abs(hi - lo) // 2) + 1)
        wl = np.linspace(lo, hi, n_stops)

        colorlist = list(zip(norm(wl), [wavelength_to_rgb(w) for w in wl]))
        self.spectralmap = matplotlib.colors.LinearSegmentedColormap.from_list(
            "spectrum", colorlist)

    def image_show_slice(self, n):
        """Show slice ``n`` of the cube and remember it as the current one."""
        self.image.set_naxispath([n])
        self._viewer.redraw(whence=0)
        self._cur_islice = n

    def show_slice(self, wave):
        """Show the slice nearest to wavelength ``wave`` and replot.

        Uses the same mapping as the plotted wavelength axis
        (slice ``i`` is ``wave_start + i * dwave``), clamped to the valid
        slice range so that neither end of the slider selects a slice
        outside the cube.
        """
        n = int(round((wave - self.wave_start) / self.dwave))
        n = min(max(n, 0), self.nwave - 1)

        self.image_show_slice(n)
        self.plot_line_profile()