def makeSlitIllum(self, adinputs=None, **params):
        """
        Makes the processed Slit Illumination Function by binning a 2D
        spectrum along the dispersion direction, fitting a smooth function
        for each bin, fitting a smooth 2D model, and reconstructing the 2D
        array using this last model.

        Its implementation is based on IRAF's
        `noao.twodspec.longslit.illumination` task, following the algorithm
        described in [Valdes, 1986].

        It expects the input calibration image to be a dispersed image of
        the slit without illumination problems (e.g., a twilight flat). The
        spectrum is not required to be smooth in wavelength and may contain
        strong emission and absorption lines. The image should contain a
        `.mask` attribute in each extension, and it is expected to be
        overscan and bias corrected.

        Multi-extension inputs are temporarily mosaicked (a "mosaic" gWCS
        frame is added first when missing). The reconstructed slit response
        is then split back into the original extensions.

        Parameters
        ----------
        adinputs : list
            List of AstroData objects containing the dispersed image of the
            slit of a source free of illumination problems. The data needs to
            have been overscan and bias corrected and is expected to have a
            Data Quality mask.
        suffix : str
            Suffix added to the output filenames.
        bins : {None, int}, optional
            Total number of bins across the dispersion axis. If None,
            ``max(number of extensions, 12)`` equally sized bins are used.
            If it is an int, it will create N bins with the same size.
        border : int, optional
            Border size that is added on every edge of the slit illumination
            image before cutting it down to the input AstroData frame.
        debug_plot : bool, optional
            If True, save diagnostic plots to a PNG file named after the
            output file.
        smooth_order : int, optional
            Order of the spline that is used in each bin fitting to smooth
            the data (Default: 3)
        x_order : int, optional
            Order of the x-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.
        y_order : int, optional
            Order of the y-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.

        Returns
        -------
        list of AstroData
            One AstroData per input, containing the Slit Illumination
            Response Function.

        Raises
        ------
        TypeError
            If `bins` is neither None nor an int.

        References
        ----------
        .. [Valdes, 1986] Francisco Valdes "Reduction Of Long Slit Spectra With
           IRAF", Proc. SPIE 0627, Instrumentation in Astronomy VI,
           (13 October 1986); https://doi.org/10.1117/12.968155
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params["suffix"]
        bins = params["bins"]
        border = params["border"]
        debug_plot = params["debug_plot"]
        smooth_order = params["smooth_order"]
        cheb2d_x_order = params["x_order"]
        cheb2d_y_order = params["y_order"]

        # Slice removing `border` pixels from each edge. A plain
        # `[border:-border]` would be empty when border == 0.
        trim = slice(border, -border if border else None)

        ad_outputs = []
        for ad in adinputs:

            if len(ad) > 1:
                # deepcopy prevents modifying input `ad` inplace
                ad = deepcopy(ad)

                if "mosaic" not in ad[0].wcs.available_frames:
                    log.info('Add "mosaic" gWCS frame to input data')
                    geotable = import_module('.geometry_conf',
                                             self.inst_lookups)
                    ad = transform.add_mosaic_wcs(ad, geotable)

                # Multi-extension data must always be mosaicked, even when
                # the "mosaic" frame was already present.
                log.info("Temporarily mosaicking multi-extension file")
                mosaicked_ad = transform.resample_from_wcs(
                    ad,
                    "mosaic",
                    attributes=None,
                    order=1,
                    process_objcat=False)

            else:

                log.info("Input data has a single extension; no mosaicking "
                         "needed.")

                # deepcopy prevents modifying input `ad` inplace
                mosaicked_ad = deepcopy(ad)

            log.info("Transposing data if needed")
            dispaxis = 2 - mosaicked_ad[0].dispersion_axis()  # python sense
            should_transpose = dispaxis == 1

            data, mask, variance = _transpose_if_needed(
                mosaicked_ad[0].data,
                mosaicked_ad[0].mask,
                mosaicked_ad[0].variance,
                transpose=should_transpose)

            log.info("Masking data")
            data = np.ma.masked_array(data, mask=mask)
            variance = np.ma.masked_array(variance, mask=mask)
            std = np.sqrt(variance)  # Easier to work with

            log.info("Creating bins for data and variance")
            height, width = data.shape

            if bins is None:
                nbins = max(len(ad), 12)
            elif isinstance(bins, int):
                nbins = bins
            else:
                # ToDo: Handle input bins as array
                raise TypeError("Expected None or Int for `bins`. "
                                "Found: {}".format(type(bins)))

            bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            bin_top = bin_limits[1:]
            bin_bot = bin_limits[:-1]
            binned_data = np.zeros_like(data)
            binned_std = np.zeros_like(std)

            log.info("Smooth binned data and variance, and normalize them by "
                     "smoothed central value")
            # Pixel coordinates along axis 1, i.e. perpendicular to the
            # (possibly transposed) axis being binned.
            cols = np.arange(width)

            for b0, b1 in zip(bin_bot, bin_top):

                avg_data = np.ma.mean(data[b0:b1], axis=0)
                model_1d_data = astromodels.UnivariateSplineWithOutlierRemoval(
                    cols, avg_data, order=smooth_order)

                avg_std = np.ma.mean(std[b0:b1], axis=0)
                model_1d_std = astromodels.UnivariateSplineWithOutlierRemoval(
                    cols, avg_std, order=smooth_order)

                # Normalize both profiles by the smoothed value at the centre
                smoothed_data = model_1d_data(cols)
                slit_central_value = smoothed_data[width // 2]
                binned_data[b0:b1] = smoothed_data / slit_central_value
                binned_std[b0:b1] = model_1d_std(cols) / slit_central_value

            log.info("Reconstruct 2D mosaicked data")
            bin_center = np.array(0.5 * (bin_bot + bin_top), dtype=int)
            cols_fit, rows_fit = np.meshgrid(np.arange(width), bin_center)

            fitter = fitting.SLSQPLSQFitter()
            model_2d_init = models.Chebyshev2D(x_degree=cheb2d_x_order,
                                               x_domain=(0, width),
                                               y_degree=cheb2d_y_order,
                                               y_domain=(0, height))

            model_2d_data = fitter(model_2d_init, cols_fit, rows_fit,
                                   binned_data[rows_fit, cols_fit])

            model_2d_std = fitter(model_2d_init, cols_fit, rows_fit,
                                  binned_std[rows_fit, cols_fit])

            # Evaluate the 2D models on a grid enlarged by `border` pixels
            rows_val, cols_val = \
                np.mgrid[-border:height+border, -border:width+border]

            slit_response_data = model_2d_data(cols_val, rows_val)
            slit_response_mask = np.pad(
                mask, border, mode='edge')  # ToDo: any update to the mask?
            slit_response_std = model_2d_std(cols_val, rows_val)
            slit_response_var = slit_response_std**2

            del cols_fit, cols_val, rows_fit, rows_val

            _data, _mask, _variance = _transpose_if_needed(
                slit_response_data,
                slit_response_mask,
                slit_response_var,
                transpose=should_transpose)

            log.info("Update slit response data and data_section")
            slit_response_ad = deepcopy(mosaicked_ad)
            slit_response_ad[0].data = _data
            slit_response_ad[0].mask = _mask
            slit_response_ad[0].variance = _variance

            if "mosaic" in ad[0].wcs.available_frames:

                log.info(
                    "Map coordinates between slit function and mosaicked data"
                )  # ToDo: Improve message?
                slit_response_ad = _split_mosaic_into_extensions(
                    ad, slit_response_ad, border_size=border)

            elif len(ad) == 1:

                log.info("Trim out borders")

                slit_response_ad[0].data = \
                    slit_response_ad[0].data[trim, trim]
                slit_response_ad[0].mask = \
                    slit_response_ad[0].mask[trim, trim]
                slit_response_ad[0].variance = \
                    slit_response_ad[0].variance[trim, trim]

            log.info("Update metadata and filename")
            gt.mark_history(slit_response_ad,
                            primname=self.myself(),
                            keyword=timestamp_key)

            slit_response_ad.update_filename(suffix=suffix, strip=True)
            ad_outputs.append(slit_response_ad)

            # Plotting ------
            if debug_plot:

                log.info("Creating plots")
                palette = copy(plt.cm.cividis)
                palette.set_bad('r', 0.75)

                def _hide_spines(axis):
                    """Hide every spine of `axis`."""
                    for side in axis.spines:
                        axis.spines[side].set_visible(False)

                bin_labels = ["Bin {}".format(i)
                              for i in range(len(bin_center))]

                norm = vis.ImageNormalize(data[~data.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                fig = plt.figure(
                    num="Slit Response from MEF - {}".format(ad.filename),
                    figsize=(12, 9),
                    dpi=110)

                gs = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

                # Display raw mosaicked data and its bins ---
                ax1 = fig.add_subplot(gs[0, 0])
                im1 = ax1.imshow(data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax1.set_title("Mosaicked Data\n and Spectral Bins",
                              fontsize=10)
                ax1.set_xlim(-1, data.shape[1])
                ax1.set_xticks([])
                ax1.set_ylim(-1, data.shape[0])
                ax1.set_yticks(bin_center)
                ax1.tick_params(axis='both', which='both', length=0)
                ax1.set_yticklabels(bin_labels, fontsize=6)

                _hide_spines(ax1)
                for limit in bin_limits:
                    ax1.axhline(limit, c='w', lw=0.5)

                divider = make_axes_locatable(ax1)
                cax1 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im1, cax=cax1)

                # Display binned, smoothed and normalized data ---
                ax2 = fig.add_subplot(gs[0, 1])
                im2 = ax2.imshow(binned_data, cmap=palette, origin='lower')

                ax2.set_title("Binned, smoothed\n and normalized data ",
                              fontsize=10)
                ax2.set_xlim(0, data.shape[1])
                ax2.set_xticks([])
                ax2.set_ylim(0, data.shape[0])
                ax2.set_yticks(bin_center)
                ax2.tick_params(axis='both', which='both', length=0)
                ax2.set_yticklabels(bin_labels, fontsize=6)

                _hide_spines(ax2)
                for limit in bin_limits:
                    ax2.axhline(limit, c='w', lw=0.5)

                divider = make_axes_locatable(ax2)
                cax2 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im2, cax=cax2)

                # Display reconstructed slit response ---
                vmin = slit_response_data.min()
                vmax = slit_response_data.max()

                ax3 = fig.add_subplot(gs[1, 0])
                im3 = ax3.imshow(slit_response_data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=vmin,
                                 vmax=vmax)

                ax3.set_title("Reconstructed\n Slit response", fontsize=10)
                ax3.set_xlim(0, data.shape[1])
                ax3.set_xticks([])
                ax3.set_ylim(0, data.shape[0])
                ax3.set_yticks([])
                ax3.tick_params(axis='both', which='both', length=0)
                _hide_spines(ax3)

                divider = make_axes_locatable(ax3)
                cax3 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im3, cax=cax3)

                # Display extensions ---
                ax4 = fig.add_subplot(gs[1, 1])
                ax4.set_xticks([])
                ax4.set_yticks([])
                _hide_spines(ax4)

                n_ext = len(slit_response_ad)
                sub_gs4 = gridspec.GridSpecFromSubplotSpec(
                    nrows=n_ext, ncols=1, subplot_spec=gs[1, 1], hspace=0.03)

                # The [::-1] is needed to put the first extension at the bottom
                for i, ext in enumerate(slit_response_ad[::-1]):

                    ext_data, ext_mask, _ = _transpose_if_needed(
                        ext.data,
                        ext.mask,
                        ext.variance,
                        transpose=should_transpose)

                    ext_data = np.ma.masked_array(ext_data, mask=ext_mask)

                    sub_ax = fig.add_subplot(sub_gs4[i])

                    im4 = sub_ax.imshow(ext_data,
                                        origin="lower",
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap=palette)

                    sub_ax.set_xlim(0, ext_data.shape[1])
                    sub_ax.set_xticks([])
                    sub_ax.set_ylim(0, ext_data.shape[0])
                    sub_ax.set_yticks([ext_data.shape[0] // 2])
                    sub_ax.set_yticklabels(
                        ["Ext {}".format(n_ext - i - 1)], fontsize=6)
                    _hide_spines(sub_ax)

                    if i == 0:
                        sub_ax.set_title(
                            "Multi-extension\n Slit Response Function")

                divider = make_axes_locatable(ax4)
                cax4 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im4, cax=cax4)

                # Display Signal-To-Noise Ratio ---
                snr = data / std

                norm = vis.ImageNormalize(snr[~snr.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                ax5 = fig.add_subplot(gs[0, 2])

                im5 = ax5.imshow(snr,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax5.set_title("Mosaicked Data SNR", fontsize=10)
                ax5.set_xlim(-1, data.shape[1])
                ax5.set_xticks([])
                ax5.set_ylim(-1, data.shape[0])
                ax5.set_yticks(bin_center)
                ax5.tick_params(axis='both', which='both', length=0)
                ax5.set_yticklabels(bin_labels, fontsize=6)

                _hide_spines(ax5)
                for limit in bin_limits:
                    ax5.axhline(limit, c='w', lw=0.5)

                divider = make_axes_locatable(ax5)
                cax5 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im5, cax=cax5)

                # Display Signal-To-Noise Ratio of Slit Illumination ---
                # (shares the colour limits of the data SNR for comparison)
                slit_response_snr = np.ma.masked_array(
                    slit_response_data / np.sqrt(slit_response_var),
                    mask=slit_response_mask)

                ax6 = fig.add_subplot(gs[1, 2])

                im6 = ax6.imshow(slit_response_snr,
                                 origin="lower",
                                 vmin=norm.vmin,
                                 vmax=norm.vmax,
                                 cmap=palette)

                ax6.set_xlim(0, slit_response_snr.shape[1])
                ax6.set_xticks([])
                ax6.set_ylim(0, slit_response_snr.shape[0])
                ax6.set_yticks([])
                ax6.set_title("Reconstructed\n Slit Response SNR")
                _hide_spines(ax6)

                divider = make_axes_locatable(ax6)
                cax6 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im6, cax=cax6)

                # Save plots and release the figure so repeated calls do
                # not accumulate open figures ---
                fig.tight_layout(rect=[0, 0, 0.95, 1], pad=0.5)
                fname = slit_response_ad.filename.replace(".fits", ".png")
                log.info("Saving plots to {}".format(fname))
                fig.savefig(fname)
                plt.close(fig)

        return ad_outputs
def test_split_mosaic_into_extensions(request):
    """
    Tests helper function that splits mosaicked data into several extensions
    based on another multi-extension file that contains gWCS.

    A fake GMOS-S frame is filled with a horizontal ramp plus small noise.
    It is mosaicked, padded by 10 pixels on every edge, and split back into
    extensions. The result must match the input to one decimal place,
    ignoring the outermost pixel ring of each extension.
    """
    astrofaker = pytest.importorskip("astrofaker")

    # A fixed seed keeps the injected noise, and hence the test, reproducible
    rng = np.random.default_rng(seed=0)

    ad = astrofaker.create('GMOS-S')
    ad.init_default_extensions(binning=2)

    ad = transform.add_mosaic_wcs(ad, geotable)
    ad = gt.trim_to_data_section(
        ad, keyword_comments={'NAXIS1': "", 'NAXIS2': "", 'DATASEC': "",
                              'TRIMSEC': "", 'CRPIX1': "", 'CRPIX2': ""})

    for ext in ad:
        x1 = ext.detector_section().x1
        x2 = ext.detector_section().x2
        xb = ext.detector_x_bin()

        # Ramp in (binned) detector x-coordinates, repeated over all rows
        data = np.arange(x1 // xb, x2 // xb)[np.newaxis, :]
        data = np.repeat(data, ext.data.shape[0], axis=0)
        data = data + 0.1 * (0.5 - rng.random(data.shape))

        ext.data = data

    mosaic_ad = transform.resample_from_wcs(
        ad, "mosaic", attributes=None, order=1, process_objcat=False)

    mosaic_ad[0].data = np.pad(mosaic_ad[0].data, 10, mode='edge')

    mosaic_ad[0].hdr[mosaic_ad._keyword_for('data_section')] = \
        '[1:{},1:{}]'.format(*mosaic_ad[0].shape[::-1])

    ad2 = primitives_gmos_longslit._split_mosaic_into_extensions(
        ad, mosaic_ad, border_size=10)

    if request.config.getoption("--do-plots"):

        palette = copy(plt.cm.viridis)
        palette.set_bad('r', 1)

        fig = plt.figure(num="Test: Split Mosaic Into Extensions",
                         figsize=(8, 6.5), dpi=120)
        fig.suptitle("Test Split Mosaic Into Extensions\n Difference between"
                     " input and mosaicked/demosaicked data")

        gs = fig.add_gridspec(nrows=4, ncols=len(ad) // 3, wspace=0.1,
                              height_ratios=[1, 1, 1, 0.1])

        # Zoom windows (xlim, ylim) on the four corners of each extension,
        # in the order: top-left, top-right, bottom-left, bottom-right
        def _corner_limits(height, width, size=25):
            return [((0, size), (height - size, height)),
                    ((width - size, width), (height - size, height)),
                    ((0, size), (0, size)),
                    ((width - size, width), (0, size))]

        for i, (ext, ext2) in enumerate(zip(ad, ad2)):

            data1 = ext.data
            data2 = ext2.data
            diff = np.ma.masked_array(data1 - data2,
                                      mask=np.abs(data1 - data2) > 1)
            height, width = data1.shape

            row = i // 4
            col = i % 4

            ax = fig.add_subplot(gs[row, col])
            ax.set_title("Ext {}".format(i + 1))
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])
            for side in ax.spines:
                ax.spines[side].set_visible(False)

            if col == 0:
                ax.set_ylabel("Det {}".format(row + 1))

            # Subdivide the cell (a SubplotSpec, as the API expects)
            sub_gs = gridspec.GridSpecFromSubplotSpec(
                2, 2, subplot_spec=gs[row, col], wspace=0.05, hspace=0.05)

            for j, (xlim, ylim) in enumerate(_corner_limits(height, width)):
                sx = fig.add_subplot(sub_gs[j])
                im = sx.imshow(diff, origin='lower', cmap=palette,
                               vmin=-0.1, vmax=0.1)

                sx.set_xticks([])
                sx.set_yticks([])
                sx.set_xticklabels([])
                sx.set_yticklabels([])
                for side in sx.spines:
                    sx.spines[side].set_visible(False)

                sx.set_xlim(*xlim)
                sx.set_ylim(*ylim)

        cax = fig.add_subplot(gs[3, :])
        cbar = plt.colorbar(im, cax=cax, orientation="horizontal")
        cbar.set_label("Difference levels")

        os.makedirs(PLOT_PATH, exist_ok=True)

        fig.savefig(
            os.path.join(PLOT_PATH, "test_split_mosaic_into_extensions.png"))
        plt.close(fig)

    # Actual test ----
    # The outermost pixel ring of each extension is excluded. The mask (when
    # present) is trimmed identically so its shape matches the data slice.
    inner = (slice(1, -1), slice(1, -1))
    for ext, ext2 in zip(ad, ad2):
        mask1 = None if ext.mask is None else ext.mask[inner]
        mask2 = None if ext2.mask is None else ext2.mask[inner]
        data1 = np.ma.masked_array(ext.data[inner], mask=mask1)
        data2 = np.ma.masked_array(ext2.data[inner], mask=mask2)

        np.testing.assert_almost_equal(data1, data2, decimal=1)
Beispiel #3
0
    def QECorrect(self, adinputs=None, **params):
        """
        This primitive applies a wavelength-dependent QE correction to
        a 2D spectral image, based on the wavelength solution of an
        associated processed_arc.

        It is only designed to work on FLATs, and therefore unmosaicked data.
        FLAT frames are multiplied by the correction; other frames are
        divided by it. Extensions on the middle of the three detectors are
        left untouched. If no usable arc is available, the wavelength
        solution in the science frame's own WCS is used instead, unless the
        reduction is strict (see `do_cal`).

        Parameters
        ----------
        suffix: str
            suffix to be added to output files
        arc : {None, AstroData, str}
            Arc(s) with distortion map. If None, processed arcs are requested
            from the calibration database.
        use_iraf : bool
            Passed through to ``qeModel``.
        do_cal : str
            'skip' disables the correction entirely. With 'force' (or when
            running in an 'sq' mode), a missing or incomplete arc is an
            error instead of a warning.

        Returns
        -------
        list of AstroData
            The input list; the frames are corrected in place.

        Raises
        ------
        OSError
            In strict mode: no arc is available, or the arc has no
            wavelength solution or no distortion model.
        ValueError
            The arc and science detector sections cannot be reconciled, the
            frame does not have 3 separate detectors, or (in 'sq' mode) no
            QE correction is found for an extension.
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        sfx = params["suffix"]
        arc = params["arc"]
        use_iraf = params["use_iraf"]
        do_cal = params["do_cal"]

        if do_cal == 'skip':
            log.warning("QE correction has been turned off.")
            return adinputs

        # Get a suitable arc frame (with distortion map) for every science AD
        if arc is None:
            arc_list = self.caldb.get_processed_arc(adinputs)
        else:
            arc_list = (arc, None)

        # Provide an arc AD object for every science frame, and an origin
        for ad, arc, origin in zip(
                *gt.make_lists(adinputs, *arc_list, force_ad=(1, ))):
            if ad.phu.get(timestamp_key):
                log.warning(f"{ad.filename}: already processed by QECorrect. "
                            "Continuing.")
                continue

            if 'e2v' in ad.detector_name(pretty=True):
                log.stdinfo(f"{ad.filename} has the e2v CCDs, so no QE "
                            "correction is necessary")
                continue

            if self.timestamp_keys['mosaicDetectors'] in ad.phu:
                log.warning(f"{ad.filename} has been processed by mosaic"
                            "Detectors so QECorrect cannot be run")
                continue

            # Determines whether to multiply or divide by QE correction
            is_flat = 'FLAT' in ad.tags

            # If the arc's binning doesn't match, we may still be able to
            # fall back to the approximate solution
            xbin, ybin = ad.detector_x_bin(), ad.detector_y_bin()
            if arc is not None and (arc.detector_x_bin() != xbin
                                    or arc.detector_y_bin() != ybin):
                log.warning("Science frame and arc have different binnings.")
                arc = None

            # The plan here is to attach the mosaic gWCS to the science frame,
            # apply an origin shift to put it in the frame of the arc, and
            # then use the arc's WCS to get the wavelength. If there's no arc,
            # we just use the science frame's WCS.
            # Since we're going to change that WCS, store copies for
            # restoration (copies, in case add_mosaic_wcs modifies the
            # existing gWCS objects in place). Restoration happens in
            # `finally` so an exception cannot leave the caller's frame with
            # an altered WCS.
            original_wcs = [deepcopy(ext.wcs) for ext in ad]
            try:
                try:
                    transform.add_mosaic_wcs(ad, geotable)
                except ValueError:
                    log.warning(f"{ad.filename} already has a 'mosaic' "
                                "coordinate frame. This is unexpected but "
                                "I'll continue.")

                if arc is None:
                    if 'sq' in self.mode or do_cal == 'force':
                        raise OSError(
                            f"No processed arc listed for {ad.filename}")
                    log.warning(f"{ad.filename}: no arc was specified. Using "
                                "wavelength solution in science frame.")
                else:
                    # OK, we definitely want to try to do this, get a
                    # wavelength solution
                    origin_str = f" (obtained from {origin})" if origin else ""
                    log.stdinfo(f"{ad.filename}: using the arc {arc.filename}"
                                f"{origin_str}")
                    if self.timestamp_keys[
                            'determineWavelengthSolution'] not in arc.phu:
                        msg = (f"Arc {arc.filename} (for {ad.filename}) has "
                               "not been wavelength calibrated.")
                        if 'sq' in self.mode or do_cal == 'force':
                            raise OSError(msg)
                        log.warning(msg)

                    # We'll be modifying this
                    arc_wcs = deepcopy(arc[0].wcs)
                    if 'distortion_corrected' not in arc_wcs.available_frames:
                        msg = (f"Arc {arc.filename} (for {ad.filename}) has "
                               "no distortion model.")
                        if 'sq' in self.mode or do_cal == 'force':
                            raise OSError(msg)
                        log.warning(msg)

                    # NB. At this point, we could have an arc that has no good
                    # wavelength solution nor distortion correction. But we
                    # will use its WCS rather than the science frame's because
                    # it must have been supplied by the user.

                    # This is GMOS so no need to be as generic as
                    # distortionCorrect
                    ad_detsec = ad.detector_section()
                    arc_detsec = arc.detector_section()[0]
                    if ((ad_detsec[0].x1, ad_detsec[-1].x2) !=
                            (arc_detsec.x1, arc_detsec.x2)):
                        raise ValueError("Cannot process the offsets between "
                                         f"{ad.filename} and {arc.filename}")

                    yoff1 = arc_detsec.y1 - ad_detsec[0].y1
                    yoff2 = arc_detsec.y2 - ad_detsec[0].y2
                    arc_ext_shapes = [
                        (ext.shape[0] - yoff1 + yoff2, ext.shape[1])
                        for ext in ad]
                    arc_corners = np.concatenate([
                        transform.get_output_corners(
                            ext.wcs.get_transform(ext.wcs.input_frame,
                                                  'mosaic'),
                            input_shape=arc_shape,
                            origin=(yoff1, 0))
                        for ext, arc_shape in zip(ad, arc_ext_shapes)],
                        axis=1)
                    arc_origin = tuple(
                        np.ceil(min(corners)) for corners in arc_corners)

                    # So this is what was applied to the ARC to get the
                    # mosaic frame to its pixel frame, in which the distortion
                    # correction model was calculated. Convert coordinates
                    # from python order to Model order.
                    origin_shift = reduce(
                        Model.__and__,
                        [models.Shift(-offset) for offset in arc_origin[::-1]])
                    arc_wcs.insert_transform(arc_wcs.input_frame,
                                             origin_shift,
                                             after=True)

                array_info = gt.array_information(ad)
                if array_info.detector_shape == (1, 3):
                    ccd2_indices = array_info.extensions[1]
                else:
                    raise ValueError(
                        f"{ad.filename} does not have 3 separate detectors")

                for index, ext in enumerate(ad):
                    if index in ccd2_indices:
                        continue

                    # Use the WCS in the extension if we don't have an arc,
                    # otherwise use the arc's mosaic->world transformation
                    if arc is None:
                        trans = ext.wcs.forward_transform
                    else:
                        trans = (ext.wcs.get_transform(ext.wcs.input_frame,
                                                       'mosaic')
                                 | arc_wcs.forward_transform)

                    ygrid, xgrid = np.indices(ext.shape)
                    # TODO: want with_units
                    # Wavelength always axis 0
                    waves = trans(xgrid, ygrid)[0] * u.nm

                    # Tapering required to prevent QE correction from blowing
                    # up at the extremes (remember, this is a ratio, not the
                    # actual QE). We use half-Gaussians to taper
                    taper = np.ones_like(ext.data)
                    taper_locut, taper_losig = 350 * u.nm, 25 * u.nm
                    taper_hicut, taper_hisig = 1200 * u.nm, 200 * u.nm
                    low = waves < taper_locut
                    high = waves > taper_hicut
                    taper[low] = np.exp(
                        -((waves[low] - taper_locut) / taper_losig)**2)
                    taper[high] = np.exp(
                        -((waves[high] - taper_hicut) / taper_hisig)**2)
                    try:
                        qe_model = qeModel(ext, use_iraf=use_iraf)
                        wave_values = (waves / u.nm).to(
                            u.dimensionless_unscaled).value
                        qe_correction = ((qe_model(wave_values)
                                          .astype(np.float32) - 1)
                                         * taper + 1)
                    except TypeError:  # qeModel() returns None
                        msg = ("No QE correction found for "
                               f"{ad.filename} extension {ext.id}")
                        if 'sq' in self.mode:
                            raise ValueError(msg)
                        log.warning(msg)
                        continue
                    log.stdinfo(f"Mean relative QE of extension {ext.id} is "
                                f"{qe_correction.mean():.5f}")
                    if not is_flat:
                        qe_correction = 1. / qe_correction
                    ext.multiply(qe_correction)
            finally:
                for ext, orig_wcs in zip(ad, original_wcs):
                    ext.wcs = orig_wcs

            # Timestamp and update the filename
            gt.mark_history(ad, primname=self.myself(), keyword=timestamp_key)
            ad.update_filename(suffix=sfx, strip=True)

        return adinputs
Beispiel #4
0
    def QECorrect(self, adinputs=None, **params):
        """
        This primitive applies a wavelength-dependent QE correction to
        a 2D spectral image, based on the wavelength solution of an
        associated processed_arc.

        It is only designed to work on FLATs, and therefore unmosaicked data.

        The following frames are skipped with a message:

        - frames already processed by QECorrect;
        - frames that use the e2v CCDs;
        - frames that have been processed by mosaicDetectors.

        Extensions belonging to the second of the three detectors (CCD2) are
        never modified. FLATs are multiplied by the correction; all other
        frames are divided by it.

        Parameters
        ----------
        adinputs : list of AstroData
            Frames to correct. They are modified in place.
        suffix: str
            suffix to be added to output files
        arc : {None, AstroData, str}
            Arc(s) with distortion map. If None, processed arcs are obtained
            via getProcessedArc / the calibration system.

        Returns
        -------
        list of AstroData
            The input list, with corrected frames timestamped and renamed.

        Raises
        ------
        OSError
            In 'sq' mode only: no arc is available, or the arc has no
            wavelength solution or no distortion model.
        ValueError
            Raised in these cases:

            - the science and arc detector sections cannot be reconciled;
            - the frame does not have 3 separate detectors;
            - in 'sq' mode only, no QE model exists for an extension.
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        sfx = params["suffix"]
        arc = params["arc"]

        # Get a suitable arc frame (with distortion map) for every science AD
        if arc is None:
            self.getProcessedArc(adinputs, refresh=False)
            arc_list = self._get_cal(adinputs, 'processed_arc')
        else:
            arc_list = arc

        for ad, arc in zip(*gt.make_lists(adinputs, arc_list, force_ad=True)):
            if ad.phu.get(timestamp_key):
                log.warning("No changes will be made to {}, since it has "
                            "already been processed by QECorrect".
                            format(ad.filename))
                continue

            if 'e2v' in ad.detector_name(pretty=True):
                log.stdinfo(f"{ad.filename} has the e2v CCDs, so no QE "
                            "correction is necessary")
                continue

            if self.timestamp_keys['mosaicDetectors'] in ad.phu:
                log.warning(f"{ad.filename} has been processed by mosaic"
                            "Detectors so QECorrect cannot be run")
                continue

            # Determines whether to multiply or divide by QE correction
            is_flat = 'FLAT' in ad.tags

            # If the arc's binning doesn't match, we may still be able to
            # fall back to the approximate solution
            xbin, ybin = ad.detector_x_bin(), ad.detector_y_bin()
            if arc is not None and (arc.detector_x_bin() != xbin or
                                    arc.detector_y_bin() != ybin):
                log.warning("Science frame {} and arc {} have different binnings, "
                            "so cannot use arc".format(ad.filename, arc.filename))
                arc = None

            # The plan here is to attach the mosaic gWCS to the science frame,
            # apply an origin shift to put it in the frame of the arc, and
            # then use the arc's WCS to get the wavelength. If there's no arc,
            # we just use the science frame's WCS.
            # Since we're going to change that WCS, store it for restoration.
            original_wcs = [ext.wcs for ext in ad]
            try:
                # NOTE(review): 'geotable' is not defined in this method;
                # presumably a module-level import of the instrument's
                # geometry_conf lookup module -- confirm.
                transform.add_mosaic_wcs(ad, geotable)
            except ValueError:
                log.warning(f"{ad.filename} already has a 'mosaic' coordinate "
                            "frame. This is unexpected but I'll continue.")

            if arc is None:
                if 'sq' in self.mode:
                    raise OSError(f"No processed arc listed for {ad.filename}")
                else:
                    log.warning(f"No arc supplied for {ad.filename}")
            else:
                # OK, we definitely want to try to do this, get a wavelength solution
                if self.timestamp_keys['determineWavelengthSolution'] not in arc.phu:
                    msg = f"Arc {arc.filename} (for {ad.filename}) has not been wavelength calibrated."
                    if 'sq' in self.mode:
                        raise OSError(msg)
                    else:
                        log.warning(msg)

                # We'll be modifying this
                arc_wcs = deepcopy(arc[0].wcs)
                if 'distortion_corrected' not in arc_wcs.available_frames:
                    msg = f"Arc {arc.filename} (for {ad.filename}) has no distortion model."
                    if 'sq' in self.mode:
                        raise OSError(msg)
                    else:
                        log.warning(msg)

                # NB. At this point, we could have an arc that has no good
                # wavelength solution nor distortion correction. But we will
                # use its WCS rather than the science frame's because it must
                # have been supplied by the user.

                # This is GMOS so no need to be as generic as distortionCorrect
                ad_detsec = ad.detector_section()
                arc_detsec = arc.detector_section()[0]
                # Only vertical offsets between science and arc are handled
                if (ad_detsec[0].x1, ad_detsec[-1].x2) != (arc_detsec.x1, arc_detsec.x2):
                    raise ValueError("I don't know how to process the "
                                     f"offsets between {ad.filename} "
                                     f"and {arc.filename}")

                yoff1 = arc_detsec.y1 - ad_detsec[0].y1
                yoff2 = arc_detsec.y2 - ad_detsec[0].y2
                arc_ext_shapes = [(ext.shape[0] - yoff1 + yoff2,
                                   ext.shape[1]) for ext in ad]
                arc_corners = np.concatenate([transform.get_output_corners(
                    ext.wcs.get_transform(ext.wcs.input_frame, 'mosaic'),
                    input_shape=arc_shape, origin=(yoff1, 0))
                    for ext, arc_shape in zip(ad, arc_ext_shapes)], axis=1)
                arc_origin = tuple(np.ceil(min(corners)) for corners in arc_corners)

                # So this is what was applied to the ARC to get the
                # mosaic frame to its pixel frame, in which the distortion
                # correction model was calculated. Convert coordinates
                # from python order to Model order.
                origin_shift = reduce(Model.__and__, [models.Shift(-origin)
                                                      for origin in arc_origin[::-1]])
                arc_wcs.insert_transform(arc_wcs.input_frame, origin_shift, after=True)

            array_info = gt.array_information(ad)
            if array_info.detector_shape == (1, 3):
                ccd2_indices = array_info.extensions[1]
            else:
                raise ValueError(f"{ad.filename} does not have 3 separate detectors")

            for index, ext in enumerate(ad):
                # CCD2 extensions are left uncorrected
                if index in ccd2_indices:
                    continue

                # Use the WCS in the extension if we don't have an arc,
                # otherwise use the arc's mosaic->world transformation
                if arc is None:
                    trans = ext.wcs.forward_transform
                else:
                    trans = (ext.wcs.get_transform(ext.wcs.input_frame, 'mosaic') |
                             arc_wcs.forward_transform)

                ygrid, xgrid = np.indices(ext.shape)
                # TODO: want with_units
                waves = trans(xgrid, ygrid)[0] * u.nm  # Wavelength always axis 0
                try:
                    qe_correction = qeModel(ext)(
                        (waves / u.nm).to(u.dimensionless_unscaled).value
                    ).astype(np.float32)
                except TypeError:  # qeModel() returns None
                    msg = "No QE correction found for {}:{}".format(ad.filename, ext.hdr['EXTVER'])
                    if 'sq' in self.mode:
                        raise ValueError(msg)
                    log.warning(msg)
                    # No model means nothing to apply: skip this extension
                    # instead of using an undefined or stale correction.
                    continue
                log.stdinfo("Mean relative QE of EXTVER {} is {:.5f}".
                            format(ext.hdr['EXTVER'], qe_correction.mean()))
                if not is_flat:
                    qe_correction = 1. / qe_correction
                # Zero out nonsensical factors (negative, or > 10 e.g. from
                # inverting near-zero model values) rather than apply them
                qe_correction[qe_correction < 0] = 0
                qe_correction[qe_correction > 10] = 0
                ext.multiply(qe_correction)

            # Restore the WCS as it was before the mosaic frame was attached
            for ext, orig_wcs in zip(ad, original_wcs):
                ext.wcs = orig_wcs

            # Timestamp and update the filename
            gt.mark_history(ad, primname=self.myself(), keyword=timestamp_key)
            ad.update_filename(suffix=sfx, strip=True)

        return adinputs
# Example #5
# 0
    def mosaicDetectors(self, adinputs=None, **params):
        """
        Combine all the arrays of each AstroData object into one mosaic.

        Each input gets a 'mosaic' coordinate frame attached to its WCS,
        built from the instrument's geometry_conf.py lookup module. The input
        is then resampled onto that frame. An input is passed through
        unchanged, with a warning, in any of these cases:

        - it was already mosaicked;
        - it has a single extension;
        - it holds non-floating-point data.

        Parameters
        ----------
        suffix: str
            suffix to be added to output files.
        sci_only: bool
            mosaic only SCI image data. Default is False
        order: int (1-5)
            order of spline interpolation

        Returns
        -------
        list of AstroData
            One entry per input. This is either the new mosaicked object or
            the original input if it was skipped.
        """
        log = self.log
        primname = self.myself()
        log.debug(gt.log_message("primitive", primname, "starting"))
        timestamp_key = self.timestamp_keys[primname]

        suffix = params['suffix']
        order = params['order']
        attributes = ['data'] if params['sci_only'] else None
        geotable = import_module('.geometry_conf', self.inst_lookups)

        def resample(adin):
            # Resample all extensions of adin onto the shared 'mosaic' frame
            return transform.resample_from_wcs(adin, "mosaic",
                                               attributes=attributes,
                                               order=order,
                                               process_objcat=False)

        outputs = []
        for ad in adinputs:
            # Checked in order; the first matching reason wins
            reason = None
            if ad.phu.get(timestamp_key):
                reason = ("No changes will be made to {}, since it has "
                          "already been processed by mosaicDetectors")
            elif len(ad) == 1:
                reason = ("{} has only one extension, so there's nothing "
                          "to mosaic")
            elif not all(np.issubdtype(ext.data.dtype, np.floating)
                         for ext in ad):
                reason = ("Cannot mosaic {} with non-floating point data. "
                          "Use tileArrays instead")

            if reason is not None:
                log.warning(reason.format(ad.filename))
                outputs.append(ad)
                continue

            transform.add_mosaic_wcs(ad, geotable)

            # Resampling fails if overscan sections are still present. That
            # one failure is recoverable by trimming and retrying; any other
            # ValueError is propagated unchanged.
            try:
                mosaic = resample(ad)
            except ValueError as err:
                if 'data sections' not in repr(err):
                    raise
                ad = gt.trim_to_data_section(ad, self.keyword_comments)
                mosaic = resample(ad)

            mosaic.orig_filename = ad.filename
            gt.mark_history(mosaic, primname=primname, keyword=timestamp_key)
            mosaic.update_filename(suffix=suffix, strip=True)
            outputs.append(mosaic)

        return outputs