예제 #1
0
    def _resample_to_new_frame(self, adinputs=None, frame=None, order=3,
                               conserve=True, output_shape=None, origin=None,
                               clean_data=False, process_objcat=False):
        """
        Resample a set of AstroData objects onto a CoordinateFrame they share.

        This is a thin wrapper around transform.resample_from_wcs(). When
        the caller does not supply both `output_shape` and `origin`, the
        missing value(s) are derived from the footprint of every extension
        of every input in the target frame, so that all inputs are
        resampled onto the same pixel grid.

        Parameters
        ----------
        adinputs: list
            AstroData objects to resample
        frame: str
            name of CoordinateFrame to be resampled to
        order: int (0-5)
            order of interpolation (0=nearest, 1=linear, etc.)
        conserve: bool
            passed through unchanged to transform.resample_from_wcs()
            (presumably requests flux conservation -- confirm against that
            function's documentation)
        output_shape : tuple/None
            shape of output image (if None, calculate and use shape that
            contains all resampled inputs)
        origin: tuple/None
            location of origin in resampled output (i.e., data to the left
            of or below this will be cut)
        clean_data : bool
            replace bad pixels with a ring median of their values to avoid
            ringing if using a high-order interpolation?
        process_objcat : bool
            update (rather than delete) the OBJCAT?

        Returns
        -------
        list
            one resampled object per input, in input order
        """
        log = self.log

        if clean_data:
            # Pixels flagged by (DQ.not_signal ^ DQ.no_data) are replaced by
            # a ring median before interpolating.
            self.applyDQPlane(adinputs,
                              replace_flags=DQ.not_signal ^ DQ.no_data,
                              replace_value="median", inner=3, outer=5)

        if origin is None or output_shape is None:
            # Collect the transformed corners of every extension so the
            # output grid can be sized to hold all of them.
            footprints = []
            for ad in adinputs:
                for ext in ad:
                    to_frame = ext.wcs.get_transform(ext.wcs.input_frame,
                                                     frame)
                    footprints.append(transform.get_output_corners(
                        to_frame, input_shape=ext.shape))
            all_corners = np.concatenate(footprints, axis=1)

            # One entry per output axis: the first whole pixel at or beyond
            # the smallest corner coordinate.
            lower_edges = [np.ceil(min(axis_corners))
                           for axis_corners in all_corners]

            if output_shape is None:
                output_shape = tuple(
                    int(np.floor(max(axis_corners)) - lower + 1)
                    for axis_corners, lower in zip(all_corners, lower_edges))
            if origin is None:
                origin = tuple(lower_edges)

        log.stdinfo(f"Output image will have shape {output_shape[::-1]!r}")

        resampled = []
        for ad in adinputs:
            log.stdinfo(f"Resampling {ad.filename}")
            resampled.append(transform.resample_from_wcs(
                ad, frame, order=order, conserve=conserve,
                output_shape=output_shape, origin=origin,
                process_objcat=process_objcat))

        return resampled
예제 #2
0
    def _resample_to_new_frame(self, adinputs=None, frame=None, order=3,
                               trim_data=False, clean_data=False,
                               process_objcat=False):
        """
        This private method resamples a number of AstroData objects to a
        CoordinateFrame they share. It is basically just a wrapper for the
        transform.resample_from_wcs() method that creates an
        appropriately-sized output based on the complete set of input
        AstroData objects.

        Only the first extension of each input is used when computing the
        output footprint.

        Parameters
        ----------
        adinputs: list
            AstroData objects to resample
        frame: str
            name of CoordinateFrame to be resampled to
        order: int (0-5)
            order of interpolation (0=nearest, 1=linear, etc.)
        trim_data: bool
            trim image to size of reference image? If True, the output has
            the shape of the first extension of the first input and a zero
            origin; otherwise it is sized to contain all resampled inputs.
        clean_data: bool
            replace bad pixels with a ring median of their values to avoid
            ringing if using a high-order interpolation?
        process_objcat: bool
            update (rather than delete) the OBJCAT?

        Returns
        -------
        list
            one resampled object per input, in input order
        """
        log = self.log

        if clean_data:
            self.applyDQPlane(adinputs, replace_flags=DQ.not_signal ^ DQ.no_data,
                              replace_value="median", inner=3, outer=5)

        if trim_data:
            # Keep the pixel grid of the reference (first) image
            output_shape = adinputs[0][0].shape
            origin = (0,) * len(output_shape)
        else:
            # Size the output so that it contains every transformed input
            all_corners = np.concatenate([transform.get_output_corners(
                ad[0].wcs.get_transform(ad[0].wcs.input_frame, frame),
                input_shape=ad[0].shape) for ad in adinputs], axis=1)
            origin = tuple(np.ceil(min(corners)) for corners in all_corners)
            output_shape = tuple(int(np.floor(max(corners)) - np.ceil(min(corners)) + 1)
                                 for corners in all_corners)

        # Diagnostic only: goes through the logger instead of a bare print()
        # so it does not pollute stdout during normal reductions.
        log.debug(f"Output origin: {origin}")
        log.stdinfo("Output image will have shape "+repr(output_shape[::-1]))
        adoutputs = []
        for ad in adinputs:
            log.stdinfo(f"Resampling {ad.filename}")
            ad_out = transform.resample_from_wcs(ad, frame, order=order,
                                                 output_shape=output_shape, origin=origin,
                                                 process_objcat=process_objcat)
            adoutputs.append(ad_out)

        return adoutputs
def _split_mosaic_into_extensions(ref_ad, mos_ad, border_size=0):
    """
    Split the `mos_ad` mosaicked data into multiple extensions using
    coordinate frames and transformations stored in the `ref_ad` object.

    Right now, the pixels at the border of each extensions might not
    match the expected values. The mosaicking and de-mosaicking is an
    interpolation, because there's a small rotation. This will only interpolate,
    not extrapolate beyond the boundaries of the input data, so you lose some
    information at the edges when you perform both operations and consequently
    the edges of the input frame get lost.

    Note that `mos_ad` is modified in place: its data-section header keyword
    is rewritten to span the full array and its `wcs` is replaced once per
    reference extension.

    Parameters
    ----------
    ref_ad : AstroData
        Reference multi-extension-file object containing a gWCS.
    mos_ad : AstroData
        Mosaicked data that will be split containing a single extension.
    border_size : int
        Number of pixels to be trimmed out from each border.

    Returns
    -------
    AstroData : Split multi-extension-file object.

    Raises
    ------
    ValueError
        If `mos_ad` does not have exactly one extension, or if that
        extension is not two-dimensional.

    See Also
    --------
    - :func:`gempy.library.transform.add_mosaic_wcs`
    - :func:`gempy.library.transform.resample_from_wcs`
    """
    # Check input data. An empty `mos_ad` is rejected here as well, rather
    # than failing later on `mos_ad[0]` with an unrelated indexing error.
    if len(mos_ad) != 1:
        raise ValueError("Expected number of extensions of `mos_ad` to be 1. "
                         "Found {:d}".format(len(mos_ad)))

    if len(mos_ad[0].shape) != 2:
        raise ValueError("Expected ndim for `mos_ad` to be 2. "
                         "Found {:d}".format(len(mos_ad[0].shape)))

    # Get original relative shift
    origin_shift_y, origin_shift_x = mos_ad[0].nddata.meta['transform'][
        'origin']

    # Create shift transformation
    shift_x = models.Shift(origin_shift_x - border_size)
    shift_y = models.Shift(origin_shift_y - border_size)

    # Create empty AD
    ad_out = astrodata.create(ref_ad.phu)

    # Update data_section to be able to resample WCS frames
    datasec_kw = mos_ad._keyword_for('data_section')
    mos_ad[0].hdr[datasec_kw] = '[1:{},1:{}]'.format(*mos_ad[0].shape[::-1])

    # The keyword name does not depend on the extension being processed
    ref_datasec_kw = ref_ad._keyword_for('data_section')

    # Loop across all extensions
    for ref_ext in ref_ad:

        # Create new transformation pipeline
        in_frame = ref_ext.wcs.input_frame
        mos_frame = coordinate_frames.Frame2D(name="mosaic")

        mosaic_to_pixel = ref_ext.wcs.get_transform(mos_frame, in_frame)

        pipeline = [(mos_frame, mosaic_to_pixel), (in_frame, None)]

        mos_ad[0].wcs = gWCS(pipeline)

        # Shift mosaic in order to set reference (0, 0) on Detector 2
        mos_ad[0].wcs.insert_transform(mos_frame,
                                       shift_x & shift_y,
                                       after=True)

        # Apply transformation
        temp_ad = transform.resample_from_wcs(mos_ad,
                                              in_frame.name,
                                              origin=(0, 0),
                                              output_shape=ref_ext.shape)

        # Update data_section
        temp_ad[0].hdr[ref_datasec_kw] = \
            '[1:{:d},1:{:d}]'.format(*temp_ad[0].shape[::-1])

        # Copy the detector and array sections from the reference extension.
        # If a descriptor returned something, write it as a 1-indexed,
        # inclusive section string; otherwise remove the keyword.
        for descriptor in ('detector_section', 'array_section'):
            section_kw = ref_ext._keyword_for(descriptor)
            section = getattr(ref_ext, descriptor)()

            if section:
                temp_ad[0].hdr[section_kw] = \
                    '[{}:{},{}:{}]'.format(
                        section.x1 + 1, section.x2,
                        section.y1 + 1, section.y2)
            else:
                del temp_ad[0].hdr[section_kw]

        ad_out.append(temp_ad[0])

    return ad_out
    def makeSlitIllum(self, adinputs=None, **params):
        """
        Makes the processed Slit Illumination Function by binning a 2D
        spectrum along the dispersion direction, fitting a smooth function
        for each bin, fitting a smooth 2D model, and reconstructing the 2D
        array using this last model.

        Its implementation is based on IRAF's
        `noao.twodspec.longslit.illumination` task, following the algorithm
        described in [Valdes, 1986].

        It expects the input calibration image to be a dispersed image of the
        slit without illumination problems (e.g, twilight flat). The spectra
        are not required to be smooth in wavelength and may contain strong
        emission and absorption lines. The image should contain a `.mask`
        attribute in each extension, and it is expected to be overscan and
        bias corrected.

        Multi-extension inputs without a "mosaic" gWCS frame are temporarily
        mosaicked (on a deep copy), processed as a single image and split
        back into extensions at the end. The inputs themselves are not
        modified.

        Parameters
        ----------
        adinputs : list
            List of AstroData objects containing the dispersed image of the
            slit of a source free of illumination problems. The data needs to
            have been overscan and bias corrected and is expected to have a
            Data Quality mask.
        suffix : str
            Suffix applied to the filename of each output.
        bins : {None, int}, optional
            Total number of bins across the dispersion axis. If None, the
            number of bins is the larger of 12 and the number of extensions
            of the input AstroData object. If it is an int, it will create
            N bins with the same size.
        border : int, optional
            Border size that is added on every edge of the slit illumination
            image before cutting it down to the input AstroData frame.
        debug_plot : bool
            If True, a diagnostic figure is saved as a PNG named after the
            output file.
        smooth_order : int, optional
            Order of the spline that is used in each bin fitting to smooth
            the data (Default: 3)
        x_order : int, optional
            Order of the x-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.
        y_order : int, optional
            Order of the y-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.

        Returns
        -------
        List of AstroData : containing an AstroData with the Slit Illumination
            Response Function for each of the input object.

        Raises
        ------
        TypeError
            If `bins` is neither None nor an int.

        References
        ----------
        .. [Valdes, 1986] Francisco Valdes "Reduction Of Long Slit Spectra With
           IRAF", Proc. SPIE 0627, Instrumentation in Astronomy VI,
           (13 October 1986); https://doi.org/10.1117/12.968155
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params["suffix"]
        bins = params["bins"]
        border = params["border"]
        debug_plot = params["debug_plot"]
        smooth_order = params["smooth_order"]
        cheb2d_x_order = params["x_order"]
        cheb2d_y_order = params["y_order"]

        ad_outputs = []
        for ad in adinputs:

            if len(ad) > 1 and "mosaic" not in ad[0].wcs.available_frames:

                log.info('Add "mosaic" gWCS frame to input data')
                geotable = import_module('.geometry_conf', self.inst_lookups)

                # deepcopy prevents modifying input `ad` inplace
                ad = transform.add_mosaic_wcs(deepcopy(ad), geotable)

                log.info("Temporarily mosaicking multi-extension file")
                mosaicked_ad = transform.resample_from_wcs(
                    ad,
                    "mosaic",
                    attributes=None,
                    order=1,
                    process_objcat=False)

            else:

                log.info('Input data already has one extension and has a '
                         '"mosaic" frame.')

                # deepcopy prevents modifying input `ad` inplace
                mosaicked_ad = deepcopy(ad)

            log.info("Transposing data if needed")
            dispaxis = 2 - mosaicked_ad[0].dispersion_axis()  # python sense
            should_transpose = dispaxis == 1

            data, mask, variance = _transpose_if_needed(
                mosaicked_ad[0].data,
                mosaicked_ad[0].mask,
                mosaicked_ad[0].variance,
                transpose=should_transpose)

            log.info("Masking data")
            data = np.ma.masked_array(data, mask=mask)
            variance = np.ma.masked_array(variance, mask=mask)
            std = np.sqrt(variance)  # Easier to work with

            log.info("Creating bins for data and variance")
            height = data.shape[0]
            width = data.shape[1]

            if bins is None:
                nbins = max(len(ad), 12)
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            elif isinstance(bins, int):
                nbins = bins
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            else:
                # ToDo: Handle input bins as array
                raise TypeError("Expected None or Int for `bins`. "
                                "Found: {}".format(type(bins)))

            bin_top = bin_limits[1:]
            bin_bot = bin_limits[:-1]
            binned_data = np.zeros_like(data)
            binned_std = np.zeros_like(std)

            log.info("Smooth binned data and variance, and normalize them by "
                     "smoothed central value")

            # Abscissa for the 1D fits; identical for every bin
            cols = np.arange(width)

            for b0, b1 in zip(bin_bot, bin_top):

                avg_data = np.ma.mean(data[b0:b1], axis=0)
                model_1d_data = astromodels.UnivariateSplineWithOutlierRemoval(
                    cols, avg_data, order=smooth_order)

                avg_std = np.ma.mean(std[b0:b1], axis=0)
                model_1d_std = astromodels.UnivariateSplineWithOutlierRemoval(
                    cols, avg_std, order=smooth_order)

                # Both the data and its uncertainty are normalized by the
                # smoothed data value at the central column of the bin.
                smoothed_data = model_1d_data(cols)
                slit_central_value = smoothed_data[width // 2]
                binned_data[b0:b1] = smoothed_data / slit_central_value
                binned_std[b0:b1] = model_1d_std(cols) / slit_central_value

            log.info("Reconstruct 2D mosaicked data")
            bin_center = np.array(0.5 * (bin_bot + bin_top), dtype=int)
            cols_fit, rows_fit = np.meshgrid(np.arange(width), bin_center)

            fitter = fitting.SLSQPLSQFitter()
            model_2d_init = models.Chebyshev2D(x_degree=cheb2d_x_order,
                                               x_domain=(0, width),
                                               y_degree=cheb2d_y_order,
                                               y_domain=(0, height))

            model_2d_data = fitter(model_2d_init, cols_fit, rows_fit,
                                   binned_data[rows_fit, cols_fit])

            model_2d_std = fitter(model_2d_init, cols_fit, rows_fit,
                                  binned_std[rows_fit, cols_fit])

            # Evaluate the models on a grid extended by `border` pixels on
            # every side of the mosaicked frame.
            rows_val, cols_val = \
                np.mgrid[-border:height+border, -border:width+border]

            slit_response_data = model_2d_data(cols_val, rows_val)
            slit_response_mask = np.pad(
                mask, border, mode='edge')  # ToDo: any update to the mask?
            slit_response_std = model_2d_std(cols_val, rows_val)
            slit_response_var = slit_response_std**2

            del cols_fit, cols_val, rows_fit, rows_val

            # Undo the transposition applied before the fit
            _data, _mask, _variance = _transpose_if_needed(
                slit_response_data,
                slit_response_mask,
                slit_response_var,
                transpose=should_transpose)

            log.info("Update slit response data and data_section")
            slit_response_ad = deepcopy(mosaicked_ad)
            slit_response_ad[0].data = _data
            slit_response_ad[0].mask = _mask
            slit_response_ad[0].variance = _variance

            if "mosaic" in ad[0].wcs.available_frames:

                log.info(
                    "Map coordinates between slit function and mosaicked data"
                )  # ToDo: Improve message?
                slit_response_ad = _split_mosaic_into_extensions(
                    ad, slit_response_ad, border_size=border)

            elif len(ad) == 1:

                log.info("Trim out borders")

                # `-border or None` keeps the whole axis when border == 0;
                # a plain `[0:-0]` slice would return an empty array.
                trim = slice(border, -border or None)

                slit_response_ad[0].data = \
                    slit_response_ad[0].data[trim, trim]
                slit_response_ad[0].mask = \
                    slit_response_ad[0].mask[trim, trim]
                slit_response_ad[0].variance = \
                    slit_response_ad[0].variance[trim, trim]

            log.info("Update metadata and filename")
            gt.mark_history(slit_response_ad,
                            primname=self.myself(),
                            keyword=timestamp_key)

            slit_response_ad.update_filename(suffix=suffix, strip=True)
            ad_outputs.append(slit_response_ad)

            # Plotting ------
            if debug_plot:

                log.info("Creating plots")
                palette = copy(plt.cm.cividis)
                palette.set_bad('r', 0.75)

                norm = vis.ImageNormalize(data[~data.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                fig = plt.figure(num="Slit Response from MEF - {}".format(
                    ad.filename),
                                 figsize=(12, 9),
                                 dpi=110)

                gs = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

                # Display raw mosaicked data and its bins ---
                ax1 = fig.add_subplot(gs[0, 0])
                im1 = ax1.imshow(data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax1.set_title("Mosaicked Data\n and Spectral Bins",
                              fontsize=10)
                ax1.set_xlim(-1, data.shape[1])
                ax1.set_xticks([])
                ax1.set_ylim(-1, data.shape[0])
                ax1.set_yticks(bin_center)
                ax1.tick_params(axis=u'both', which=u'both', length=0)

                ax1.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax1.spines[s].set_visible(False) for s in ax1.spines]
                _ = [ax1.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax1)
                cax1 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im1, cax=cax1)

                # Display non-smoothed bins ---
                ax2 = fig.add_subplot(gs[0, 1])
                im2 = ax2.imshow(binned_data, cmap=palette, origin='lower')

                ax2.set_title("Binned, smoothed\n and normalized data ",
                              fontsize=10)
                ax2.set_xlim(0, data.shape[1])
                ax2.set_xticks([])
                ax2.set_ylim(0, data.shape[0])
                ax2.set_yticks(bin_center)
                ax2.tick_params(axis=u'both', which=u'both', length=0)

                ax2.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax2.spines[s].set_visible(False) for s in ax2.spines]
                _ = [ax2.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax2)
                cax2 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im2, cax=cax2)

                # Display reconstructed slit response ---
                vmin = slit_response_data.min()
                vmax = slit_response_data.max()

                ax3 = fig.add_subplot(gs[1, 0])
                im3 = ax3.imshow(slit_response_data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=vmin,
                                 vmax=vmax)

                ax3.set_title("Reconstructed\n Slit response", fontsize=10)
                ax3.set_xlim(0, data.shape[1])
                ax3.set_xticks([])
                ax3.set_ylim(0, data.shape[0])
                ax3.set_yticks([])
                ax3.tick_params(axis=u'both', which=u'both', length=0)
                _ = [ax3.spines[s].set_visible(False) for s in ax3.spines]

                divider = make_axes_locatable(ax3)
                cax3 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im3, cax=cax3)

                # Display extensions ---
                ax4 = fig.add_subplot(gs[1, 1])
                ax4.set_xticks([])
                ax4.set_yticks([])
                _ = [ax4.spines[s].set_visible(False) for s in ax4.spines]

                sub_gs4 = gridspec.GridSpecFromSubplotSpec(nrows=len(ad),
                                                           ncols=1,
                                                           subplot_spec=gs[1,
                                                                           1],
                                                           hspace=0.03)

                # The [::-1] is needed to put the first extension in the bottom
                for i, ext in enumerate(slit_response_ad[::-1]):

                    ext_data, ext_mask, ext_variance = _transpose_if_needed(
                        ext.data,
                        ext.mask,
                        ext.variance,
                        transpose=should_transpose)

                    ext_data = np.ma.masked_array(ext_data, mask=ext_mask)

                    sub_ax = fig.add_subplot(sub_gs4[i])

                    im4 = sub_ax.imshow(ext_data,
                                        origin="lower",
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap=palette)

                    sub_ax.set_xlim(0, ext_data.shape[1])
                    sub_ax.set_xticks([])
                    sub_ax.set_ylim(0, ext_data.shape[0])
                    sub_ax.set_yticks([ext_data.shape[0] // 2])

                    sub_ax.set_yticklabels(
                        ["Ext {}".format(len(slit_response_ad) - i - 1)],
                        fontsize=6)

                    _ = [
                        sub_ax.spines[s].set_visible(False)
                        for s in sub_ax.spines
                    ]

                    if i == 0:
                        sub_ax.set_title(
                            "Multi-extension\n Slit Response Function")

                divider = make_axes_locatable(ax4)
                cax4 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im4, cax=cax4)

                # Display Signal-To-Noise Ratio ---
                snr = data / np.sqrt(variance)

                norm = vis.ImageNormalize(snr[~snr.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                ax5 = fig.add_subplot(gs[0, 2])

                im5 = ax5.imshow(snr,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax5.set_title("Mosaicked Data SNR", fontsize=10)
                ax5.set_xlim(-1, data.shape[1])
                ax5.set_xticks([])
                ax5.set_ylim(-1, data.shape[0])
                ax5.set_yticks(bin_center)
                ax5.tick_params(axis=u'both', which=u'both', length=0)

                ax5.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax5.spines[s].set_visible(False) for s in ax5.spines]
                _ = [ax5.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax5)
                cax5 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im5, cax=cax5)

                # Display Signal-To-Noise Ratio of Slit Illumination ---
                slit_response_snr = np.ma.masked_array(
                    slit_response_data / np.sqrt(slit_response_var),
                    mask=slit_response_mask)

                ax6 = fig.add_subplot(gs[1, 2])

                im6 = ax6.imshow(slit_response_snr,
                                 origin="lower",
                                 vmin=norm.vmin,
                                 vmax=norm.vmax,
                                 cmap=palette)

                ax6.set_xlim(0, slit_response_snr.shape[1])
                ax6.set_xticks([])
                ax6.set_ylim(0, slit_response_snr.shape[0])
                ax6.set_yticks([])
                ax6.set_title("Reconstructed\n Slit Response SNR")

                _ = [ax6.spines[s].set_visible(False) for s in ax6.spines]

                divider = make_axes_locatable(ax6)
                cax6 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im6, cax=cax6)

                # Save plots ---
                fig.tight_layout(rect=[0, 0, 0.95, 1], pad=0.5)
                fname = slit_response_ad.filename.replace(".fits", ".png")
                log.info("Saving plots to {}".format(fname))
                fig.savefig(fname)

                # Release the figure: otherwise one figure per input stays
                # open, and a later call with the same `num` would draw on
                # top of the stale axes.
                plt.close(fig)

        return ad_outputs
예제 #5
0
def test_split_mosaic_into_extensions(request):
    """
    Tests the helper function that splits mosaicked data into several
    extensions, based on another multi-extension file that contains gWCS.

    A fake GMOS-S dataset is filled with a horizontal ramp plus a small
    random perturbation, mosaicked, padded by a 10-pixel border and then
    split back into extensions. The round-tripped extensions must agree with
    the inputs to one decimal place (the outermost one-pixel frame of every
    extension is excluded from the comparison).

    Parameters
    ----------
    request : fixture
        pytest ``request`` fixture; only used to read the ``--do-plots``
        command-line option, which enables the diagnostic figure.
    """
    astrofaker = pytest.importorskip("astrofaker")

    ad = astrofaker.create('GMOS-S')
    ad.init_default_extensions(binning=2)

    ad = transform.add_mosaic_wcs(ad, geotable)
    ad = gt.trim_to_data_section(
        ad, keyword_comments={'NAXIS1': "", 'NAXIS2': "", 'DATASEC': "",
                              'TRIMSEC': "", 'CRPIX1': "", 'CRPIX2': ""})

    for ext in ad:
        x1 = ext.detector_section().x1
        x2 = ext.detector_section().x2
        xb = ext.detector_x_bin()

        # Ramp along x (in binned detector pixels), identical on every row,
        # with noise of amplitude +/-0.05 added on top.
        data = np.arange(x1 // xb, x2 // xb)[np.newaxis, :]
        data = np.repeat(data, ext.data.shape[0], axis=0)
        data = data + 0.1 * (0.5 - np.random.random(data.shape))

        ext.data = data

    mosaic_ad = transform.resample_from_wcs(
        ad, "mosaic", attributes=None, order=1, process_objcat=False)

    # Add a border around the mosaic; it is passed as border_size below.
    mosaic_ad[0].data = np.pad(mosaic_ad[0].data, 10, mode='edge')

    mosaic_ad[0].hdr[mosaic_ad._keyword_for('data_section')] = \
        '[1:{},1:{}]'.format(*mosaic_ad[0].shape[::-1])

    ad2 = primitives_gmos_longslit._split_mosaic_into_extensions(
        ad, mosaic_ad, border_size=10)

    if request.config.getoption("--do-plots"):

        palette = copy(plt.cm.viridis)
        palette.set_bad('r', 1)

        fig = plt.figure(num="Test: Split Mosaic Into Extensions", figsize=(8, 6.5), dpi=120)
        fig.suptitle("Test Split Mosaic Into Extensions\n Difference between"
                     " input and mosaicked/demosaicked data")

        # Three rows of extensions plus a thin row for the colorbar. The
        # column count is derived once so that the grid and the indices used
        # to address it cannot disagree.
        ncols = len(ad) // 3
        gs = fig.add_gridspec(nrows=4, ncols=ncols, wspace=0.1, height_ratios=[1, 1, 1, 0.1])

        for i, (ext, ext2) in enumerate(zip(ad, ad2)):

            data1 = ext.data
            data2 = ext2.data
            residual = data1 - data2
            diff = np.ma.masked_array(residual, mask=np.abs(residual) > 1)
            height, width = data1.shape

            row, col = divmod(i, ncols)

            ax = fig.add_subplot(gs[row, col])
            ax.set_title("Ext {}".format(i + 1))
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])
            _ = [ax.spines[s].set_visible(False) for s in ax.spines]

            if col == 0:
                ax.set_ylabel("Det {}".format(row + 1))

            sub_gs = gridspec.GridSpecFromSubplotSpec(2, 2, ax, wspace=0.05, hspace=0.05)

            # One 25x25 pixel zoom per corner of the extension, in the order
            # top-left, top-right, bottom-left, bottom-right.
            corners = (
                ((0, 25), (height - 25, height)),
                ((width - 25, width), (height - 25, height)),
                ((0, 25), (0, 25)),
                ((width - 25, width), (0, 25)),
            )

            for j, (xlim, ylim) in enumerate(corners):
                sx = fig.add_subplot(sub_gs[j])
                im = sx.imshow(diff, origin='lower', cmap=palette, vmin=-0.1, vmax=0.1)

                sx.set_xticks([])
                sx.set_yticks([])
                sx.set_xticklabels([])
                sx.set_yticklabels([])
                _ = [sx.spines[s].set_visible(False) for s in sx.spines]

                sx.set_xlim(*xlim)
                sx.set_ylim(*ylim)

        cax = fig.add_subplot(gs[3, :])
        cbar = plt.colorbar(im, cax=cax, orientation="horizontal")
        cbar.set_label("Difference levels")

        os.makedirs(PLOT_PATH, exist_ok=True)

        fig.savefig(
            os.path.join(PLOT_PATH, "test_split_mosaic_into_extensions.png"))

        # Release the named figure so repeated runs in one session neither
        # accumulate open figures nor draw on top of a previous one.
        plt.close(fig)

    # Actual test ----
    for ext, ext2 in zip(ad, ad2):
        # The mask has to be trimmed exactly like the data; handing the
        # full-size mask to masked_array makes numpy reject the mismatch.
        mask1 = None if ext.mask is None else ext.mask[1:-1, 1:-1]
        mask2 = None if ext2.mask is None else ext2.mask[1:-1, 1:-1]

        data1 = np.ma.masked_array(ext.data[1:-1, 1:-1], mask=mask1)
        data2 = np.ma.masked_array(ext2.data[1:-1, 1:-1], mask=mask2)

        np.testing.assert_almost_equal(data1, data2, decimal=1)
예제 #6
0
    def tileArrays(self, adinputs=None, **params):
        """
        This primitive combines extensions by tiling (no interpolation).
        The array_section() and detector_section() descriptors are used
        to derive the geometry of the tiling, so outside help (from the
        instrument's geometry_conf module) is only required if there are
        multiple arrays being tiled together, as the gaps need to be
        specified.

        If the input AstroData objects still have non-data regions, these
        will not be trimmed. However, the WCS of the final image will
        only be correct for some of the image since extra space has been
        introduced into the image.

        Parameters
        ----------
        adinputs: list
            AstroData objects to be tiled
        suffix: str
            suffix to be added to output files
        tile_all: bool
            tile to a single extension, rather than one per array?
            (array=physical detector)
        sci_only: bool
            tile only the data plane?

        Returns
        -------
        list
            one entry per input, in the same order. An input is passed
            through unchanged (with a warning) if it has a single extension,
            or if tile_all is False and every array has only one amplifier.
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params['suffix']
        tile_all = params['tile_all']
        attributes = ['data'] if params["sci_only"] else None

        adoutputs = []
        for ad in adinputs:
            if len(ad) == 1:
                log.warning("{} has only one extension, so there's nothing "
                            "to tile".format(ad.filename))
                adoutputs.append(ad)
                continue

            array_info = gt.array_information(ad)
            detshape = array_info.detector_shape
            if not tile_all and set(array_info.array_shapes) == {(1, 1)}:
                log.warning("{} has nothing to tile, as tile_all=False but "
                            "each array has only one amplifier."
                            .format(ad.filename))
                adoutputs.append(ad)
                continue

            # xgap/ygap are only read further down when there is a
            # neighbouring array to push along, i.e. for detshape != (1, 1).
            if tile_all and detshape != (1, 1):  # We need gaps!
                geotable = import_module('.geometry_conf', self.inst_lookups)
                chip_gaps = geotable.tile_gaps[ad.detector_name()]
                try:
                    xgap, ygap = chip_gaps
                except TypeError:  # single number, applies to both
                    xgap = ygap = chip_gaps

            kw = ad._keyword_for('data_section')
            xbin, ybin = ad.detector_x_bin(), ad.detector_y_bin()

            # Work out additional shifts required to cope with possible
            # overscan regions, including those in already-tiled CCDs
            if tile_all:
                yorigins, xorigins = np.rollaxis(
                    np.array(array_info.origins),
                    1).reshape((2, ) + array_info.detector_shape)
                xorigins //= xbin
                yorigins //= ybin
            else:
                yorigins, xorigins = np.zeros((2, ) +
                                              array_info.detector_shape)
            it_ccd = np.nditer(xorigins, flags=['multi_index'])
            i = 0
            while not it_ccd.finished:
                ccdy, ccdx = it_ccd.multi_index
                shp = array_info.array_shapes[i]
                exts = array_info.extensions[i]
                # Accumulated offset of each extension within this array
                # caused by non-data regions of the extensions before it.
                xshifts = np.zeros(shp, dtype=np.int32)
                yshifts = np.zeros(shp, dtype=np.int32)
                it = np.nditer(np.array(exts).reshape(shp),
                               flags=['multi_index'])
                while not it.finished:
                    iy, ix = it.multi_index
                    ext = ad[int(it[0])]
                    datsec = ext.data_section()
                    # Columns left of the data move this extension and all
                    # those to its right; columns right of the data only
                    # move the ones to its right.
                    if datsec.x1 > 0:
                        xshifts[iy, ix:] += datsec.x1
                    if datsec.x2 < ext.shape[1]:
                        xshifts[iy, ix + 1:] += ext.shape[1] - datsec.x2
                    # Same logic along y: rows below the data move this
                    # extension and those above it; rows above the data only
                    # move the extensions above it. Both are *vertical*
                    # offsets and so both belong in yshifts.
                    if datsec.y1 > 0:
                        yshifts[iy:, ix] += datsec.y1
                    if datsec.y2 < ext.shape[0]:
                        yshifts[iy + 1:, ix] += ext.shape[0] - datsec.y2

                    arrsec = ext.array_section()
                    ext_shift = (models.Shift(
                        (arrsec.x1 // xbin - datsec.x1)) & models.Shift(
                            (arrsec.y1 // ybin - datsec.y1)))

                    # We need to have a "tile" Frame to resample to.
                    # We also need to perform the inverse, after the "tile"
                    # frame, of any change we make beforehand.
                    if ext.wcs is None:
                        ext.wcs = gWCS([(Frame2D(name="pixels"), ext_shift),
                                        (Frame2D(name="tile"), None)])
                    elif 'tile' not in ext.wcs.available_frames:
                        # Rebuild the pipeline with a "tile" frame placed
                        # immediately after the input frame; the transform
                        # that used to leave the input frame now leaves
                        # "tile", preceded by the inverse of ext_shift.
                        ext.wcs = gWCS([(ext.wcs.input_frame, ext_shift),
                                        (Frame2D(name="tile"),
                                         ext.wcs.pipeline[0].transform)] +
                                       ext.wcs.pipeline[1:])
                        ext.wcs.insert_transform('tile',
                                                 ext_shift.inverse,
                                                 after=True)

                    dx, dy = xshifts[iy, ix], yshifts[iy, ix]
                    if tile_all:
                        dx += xorigins[ccdy, ccdx]
                        dy += yorigins[ccdy, ccdx]
                    if dx or dy:  # Don't bother if they're both zero
                        shift_model = models.Shift(dx) & models.Shift(dy)
                        ext.wcs.insert_transform('tile',
                                                 shift_model,
                                                 after=False)
                        # Only undo the shift if something follows "tile"
                        if ext.wcs.output_frame.name != 'tile':
                            ext.wcs.insert_transform('tile',
                                                     shift_model.inverse,
                                                     after=True)

                    # Reset data_section since we're not trimming overscans
                    ext.hdr[kw] = '[1:{},1:{}]'.format(*reversed(ext.shape))
                    it.iternext()

                if tile_all:
                    # We need to shift other arrays if this one is larger than
                    # its expected size due to overscan regions. We've kept
                    # track of shifts we've introduced, but it might also be
                    # the case that we've been sent a previous tile_all=False output
                    if ccdx < detshape[1] - 1:
                        max_xshift = max(
                            xshifts.max(), ext.shape[1] -
                            (xorigins[ccdy, ccdx + 1] - xorigins[ccdy, ccdx]))
                        xorigins[ccdy, ccdx + 1:] += max_xshift + xgap // xbin
                    if ccdy < detshape[0] - 1:
                        max_yshift = max(
                            yshifts.max(), ext.shape[0] -
                            (yorigins[ccdy + 1, ccdx] - yorigins[ccdy, ccdx]))
                        yorigins[ccdy + 1:, ccdx] += max_yshift + ygap // ybin
                elif i == 0:
                    # One output extension per array: the first array
                    # creates the output object, later ones are appended.
                    ad_out = transform.resample_from_wcs(ad[exts],
                                                         "tile",
                                                         attributes=attributes,
                                                         process_objcat=True)
                else:
                    ad_out.append(
                        transform.resample_from_wcs(ad[exts],
                                                    "tile",
                                                    attributes=attributes,
                                                    process_objcat=True)[0])
                i += 1
                it_ccd.iternext()

            if tile_all:
                # Every extension now carries its final offset, so the whole
                # object can be resampled in a single call.
                ad_out = transform.resample_from_wcs(ad,
                                                     "tile",
                                                     attributes=attributes,
                                                     process_objcat=True)

            gt.mark_history(ad_out,
                            primname=self.myself(),
                            keyword=timestamp_key)
            ad_out.orig_filename = ad.filename
            ad_out.update_filename(suffix=suffix, strip=True)
            adoutputs.append(ad_out)
        return adoutputs
예제 #7
0
    def mosaicDetectors(self, adinputs=None, **params):
        """
        Build a full mosaic of all the arrays in each AD object by
        resampling it to its "mosaic" frame. The instrument must provide a
        geometry_conf.py module holding the geometric information.

        Inputs that already carry this primitive's timestamp keyword, that
        have a single extension, or that contain non-floating-point data are
        returned untouched after a warning has been logged.

        Parameters
        ----------
        suffix: str
            suffix to be added to output files.
        sci_only: bool
            mosaic only SCI image data. Default is False
        order: int (1-5)
            order of spline interpolation
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params['suffix']
        # Keyword arguments shared by both resampling attempts below
        resample_kwargs = {
            'order': params['order'],
            'attributes': ['data'] if params['sci_only'] else None,
            'process_objcat': False,
        }
        geotable = import_module('.geometry_conf', self.inst_lookups)

        adoutputs = []
        for ad in adinputs:
            # Work out whether (and why) this input has to be skipped
            if ad.phu.get(timestamp_key):
                skip_message = ("No changes will be made to {}, since it has "
                                "already been processed by mosaicDetectors")
            elif len(ad) == 1:
                skip_message = ("{} has only one extension, so there's "
                                "nothing to mosaic")
            elif any(not np.issubdtype(ext.data.dtype, np.floating)
                     for ext in ad):
                skip_message = ("Cannot mosaic {} with non-floating point "
                                "data. Use tileArrays instead")
            else:
                skip_message = None

            if skip_message is not None:
                log.warning(skip_message.format(ad.filename))
                adoutputs.append(ad)
                continue

            transform.add_mosaic_wcs(ad, geotable)

            # Resampling fails with a ValueError mentioning "data sections"
            # when an overscan section is still present. Only that specific
            # failure is recovered from, by trimming and trying once more.
            try:
                ad_out = transform.resample_from_wcs(ad, "mosaic",
                                                     **resample_kwargs)
            except ValueError as err:
                if 'data sections' not in repr(err):
                    raise err
                ad = gt.trim_to_data_section(ad, self.keyword_comments)
                ad_out = transform.resample_from_wcs(ad, "mosaic",
                                                     **resample_kwargs)

            ad_out.orig_filename = ad.filename
            gt.mark_history(ad_out, primname=self.myself(),
                            keyword=timestamp_key)
            ad_out.update_filename(suffix=suffix, strip=True)
            adoutputs.append(ad_out)

        return adoutputs