Ejemplo n.º 1
0
    def build_nirspec_lamp_output_wcs(self):
        """
        Create a spatial/spectral WCS output frame for NIRSpec lamp mode

        Creates output frame by linearly fitting x_msa, y_msa along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.  Only the first model in ``self.input_models`` is used.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            A gwcs WCS object defining the output frame WCS
        """
        model = self.input_models[0]
        wcs = model.meta.wcs
        bbox = wcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bbox)
        x_msa, y_msa, lam = np.array(wcs(*grid))
        # Handle vertical (MIRI) or horizontal (NIRSpec) dispersion.  The
        # following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
        spectral_axis = find_dispersion_axis(model)
        spatial_axis = spectral_axis ^ 1

        # Compute the wavelength array, trimming NaNs from the ends.
        # A whole slice is often all-NaN, which makes nanmedian emit a
        # RuntimeWarning.  Silence it only inside this scope so that the
        # process-wide warning filters (set by callers or test harnesses)
        # are left intact.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            wavelength_array = np.nanmedian(lam, axis=spectral_axis)
        wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

        # Sample x_msa, y_msa across the slit at the central dispersion index
        # of the bounding box.
        # NOTE(review): ``.T[index]`` selects a column of the (y, x) grid, which
        # is the cross-dispersion profile only when spectral_axis == 0.
        lam_center_index = int((bbox[spectral_axis][1] -
                                bbox[spectral_axis][0]) / 2)
        x_msa_array = x_msa.T[lam_center_index]
        y_msa_array = y_msa.T[lam_center_index]
        x_msa_array = x_msa_array[~np.isnan(x_msa_array)]
        y_msa_array = y_msa_array[~np.isnan(y_msa_array)]

        # Output pixel coordinates k / pscale_ratio for k = 0..n-1.  Building
        # them from an integer range guarantees exactly n samples, whereas
        # np.arange(0, n / r, 1 / r) may return n + 1 because of float rounding
        # and then no longer match the data it is paired with.
        x_pix = np.arange(x_msa_array.shape[0]) / self.pscale_ratio
        y_pix = np.arange(y_msa_array.shape[0]) / self.pscale_ratio
        wave_pix = np.arange(wavelength_array.shape[0]) / self.pscale_ratio

        # Estimate and fit the spatial sampling
        fitter = LinearLSQFitter()
        fit_model = Linear1D()
        pix_to_x_msa = fitter(fit_model, x_pix, x_msa_array)
        pix_to_y_msa = fitter(fit_model, y_pix, y_msa_array)

        pix_to_wavelength = Tabular1D(points=wave_pix,
                                      lookup_table=wavelength_array,
                                      bounds_error=False, fill_value=None,
                                      name='pix2wavelength')

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = wave_pix

        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=points,
                                              lookup_table=lookup_table,
                                              bounds_error=False, fill_value=None,
                                              name='wavelength2pix')

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))
        mapping.inverse = Mapping((2, 1))

        # The final transform
        # define the output wcs
        transform = mapping | pix_to_x_msa & pix_to_y_msa & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        sky = cf.Frame2D(name=f'resampled_{model.meta.wcs.output_frame.name}', axes_order=(0, 1))
        spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                                unit=(u.micron,), axes_names=('wavelength',))
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, transform),
                    (world, None)]

        output_wcs = WCS(pipeline)

        # Compute the output array size (WCS x, y order) and bounding box
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array) / self.pscale_ratio))
        x_size = len(x_msa_array)
        output_array_size[spatial_axis] = int(np.ceil(x_size / self.pscale_ratio))
        # turn the size into a numpy shape in (y, x) order
        output_wcs.array_shape = output_array_size[::-1]
        output_wcs.pixel_shape = output_array_size
        bounding_box = resample_utils.wcs_bbox_from_shape(output_array_size[::-1])
        output_wcs.bounding_box = bounding_box

        return output_wcs
Ejemplo n.º 2
0
    def build_nirspec_output_wcs(self, refmodel=None):
        """
        Create a spatial/spectral WCS covering footprint of the input

        Parameters
        ----------
        refmodel : model, optional
            Model whose WCS (``meta.wcs``), ``data`` and ``slit_ymin`` /
            ``slit_ymax`` define the reference slit.  If not given, the first
            model in ``self.input_models`` is used.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            WCS mapping output pixels -> ``slit_frame`` -> ``world``.  As a side
            effect ``self.data_size`` is set to the output ``(ny, nx)`` shape.

        Raises
        ------
        ValueError
            If no output wavelengths can be determined, or if the projected
            slit ends are not finite anywhere on the detector.
        """
        all_wcs = [m.meta.wcs for m in self.input_models if m is not refmodel]
        if refmodel:
            all_wcs.insert(0, refmodel.meta.wcs)
        else:
            refmodel = self.input_models[0]

        refwcs = refmodel.meta.wcs

        s2d = refwcs.get_transform('slit_frame', 'detector')
        d2s = refwcs.get_transform('detector', 'slit_frame')
        s2w = refwcs.get_transform('slit_frame', 'world')

        # Estimate the position of the target without relying on meta.target:
        # use the flux-weighted mean slit position and wavelength.
        bbox = refwcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bbox)
        _, s, lam = np.array(d2s(*grid))
        sd = s * refmodel.data
        ld = lam * refmodel.data
        # Both weighted quantities must be finite, otherwise a single NaN
        # wavelength would turn the wavelength mean into NaN.
        good_s = np.isfinite(sd) & np.isfinite(ld)
        total = np.sum(refmodel.data[good_s]) if np.any(good_s) else 0.0
        if np.isfinite(total) and total != 0:
            wmean_s = np.sum(sd[good_s]) / total
            wmean_l = np.sum(ld[good_s]) / total
        else:
            # No usable flux: fall back to the geometric center of the slit
            # (midpoint of ymin/ymax, not half its height) and the wavelength
            # at the center of the bounding box.
            wmean_s = 0.5 * (refmodel.slit_ymax + refmodel.slit_ymin)
            wmean_l = d2s(*np.mean(bbox, axis=1))[2]

        targ_ra, targ_dec, _ = s2w(0, wmean_s, wmean_l)

        ref_lam = np.array(_find_nirspec_output_sampling_wavelengths(
            all_wcs,
            targ_ra, targ_dec
        ))

        n_lam = ref_lam.size
        if not n_lam:
            raise ValueError("Not enough data to construct output WCS.")

        x_slit = np.zeros(n_lam)
        # ref_lam is converted with a 1e-6 factor before being fed to the
        # slit_frame transforms (presumably micron -> meter; TODO confirm).
        lam = 1e-6 * ref_lam

        # Find the spatial pixel scale:
        y_slit_min, y_slit_max = self._max_virtual_slit_extent(all_wcs, targ_ra, targ_dec)

        # Project both ends of the virtual slit onto the detector at
        # ``nsampl`` wavelengths spread evenly through ``lam``.
        nsampl = 50
        lam_sampl = lam[(np.arange(nsampl) * n_lam) // nsampl]
        xy_min = s2d(np.zeros(nsampl), np.full(nsampl, y_slit_min), lam_sampl)
        xy_max = s2d(np.zeros(nsampl), np.full(nsampl, y_slit_max), lam_sampl)

        good = np.logical_and(np.isfinite(xy_min), np.isfinite(xy_max))
        if not np.any(good):
            raise ValueError("Error estimating output WCS pixel scale.")

        # Slit-frame units per detector pixel, from the longest projected
        # length (+1 pixel) of the reference slit.
        xy1 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymin), lam)
        xy2 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymax), lam)
        xylen = np.nanmax(np.linalg.norm(np.array(xy1) - np.array(xy2), axis=0)) + 1
        pscale = (refmodel.slit_ymax - refmodel.slit_ymin) / xylen

        # Image span along the Y-axis (length of the slit in the detector
        # plane).  The output uses ``pscale_ratio`` pixels per detector pixel
        # plus a small margin.
        det_slit_span = np.nanmax(np.linalg.norm(np.subtract(xy_max, xy_min), axis=0))
        ny = int(np.ceil(det_slit_span * self.pscale_ratio + 0.5)) + 1

        # Margin, in output pixels, on each side of the slit.
        border = 0.5 * (ny - det_slit_span * self.pscale_ratio) - 0.5

        # Slit-frame step per output pixel.  The border (in output pixels) is
        # converted with this same step so that output pixel ``border`` lands
        # exactly on the slit edge for any pscale_ratio.
        y_step = pscale / self.pscale_ratio
        # NOTE(review): orientation is judged from the second sample only; if
        # either y value there is NaN the comparison is False and the
        # descending branch is taken.
        if xy_min[1][1] < xy_max[1][1]:
            y_slit_model = Linear1D(
                slope=y_step,
                intercept=y_slit_min - border * y_step
            )
        else:
            y_slit_model = Linear1D(
                slope=-y_step,
                intercept=y_slit_max + border * y_step
            )

        # Extrapolate 1/2 pixel at the edges and make tabular model w/inverse:
        lam = lam.tolist()
        pixel_coord = list(range(n_lam))

        if len(pixel_coord) > 1:
            # left: step half a pixel back from the first sample
            slope = (lam[1] - lam[0]) / (pixel_coord[1] - pixel_coord[0])
            lam.insert(0, lam[0] - 0.5 * slope)
            pixel_coord.insert(0, -0.5)
            # right: step half a pixel forward from the last sample
            slope = (lam[-1] - lam[-2]) / (pixel_coord[-1] - pixel_coord[-2])
            lam.append(lam[-1] + 0.5 * slope)
            pixel_coord.append(pixel_coord[-1] + 0.5)

        else:
            lam = 3 * lam
            pixel_coord = [-0.5, 0, 0.5]

        wavelength_transform = Tabular1D(points=pixel_coord,
                                         lookup_table=lam,
                                         bounds_error=False, fill_value=np.nan)
        wavelength_transform.inverse = Tabular1D(points=lam,
                                                 lookup_table=pixel_coord,
                                                 bounds_error=False,
                                                 fill_value=np.nan)
        self.data_size = (ny, len(ref_lam))

        # Construct the final transform
        mapping = Mapping((0, 1, 0))
        mapping.inverse = Mapping((2, 1))
        out_det2slit = mapping | Identity(1) & y_slit_model & wavelength_transform

        # Create coordinate frames
        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        slit_spatial = cf.Frame2D(name='slit_spatial', axes_order=(0, 1),
                                  unit=("", ""), axes_names=('x_slit', 'y_slit'))
        spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                                unit=(u.micron,), axes_names=('wavelength',))
        slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame')
        sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                                reference_frame=coord.ICRS())
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, out_det2slit), (slit_frame, s2w), (world, None)]
        output_wcs = WCS(pipeline)

        # Bounding box and array shape come straight from self.data_size;
        # the extra row that accounts for the half pixels at the top and
        # bottom of the slit is already included in ``ny`` above.
        bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
        output_wcs.bounding_box = bounding_box
        output_wcs.array_shape = self.data_size

        return output_wcs
Ejemplo n.º 3
0
    def build_interpolated_output_wcs(self, refmodel=None):
        """
        Create a spatial/spectral WCS output frame using all the input models

        Creates output frame by linearly fitting RA, Dec along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.

        Parameters
        ----------
        refmodel : `~jwst.datamodels.DataModel`
            The reference input image from which the fiducial WCS is created.
            If not specified, the first image in self.input_models is used.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            A gwcs WCS object defining the output frame WCS

        Raises
        ------
        ValueError
            If there are no models to build the output WCS from.
        """

        # For each input model convert slit x,y to ra,dec,lam.
        # The reference (first processed) model sets the spatial scale and
        # initializes the wavelength array; later models only extend the
        # wavelength array beyond its endpoints and contribute ra/dec values
        # that define the final footprint and tangent point.

        # Process the reference model first (mirrors build_nirspec_output_wcs).
        if refmodel is not None:
            models = [refmodel] + [m for m in self.input_models if m is not refmodel]
        else:
            models = list(self.input_models)
        if not models:
            raise ValueError("No input models to construct output WCS.")

        all_wavelength = []
        all_ra_slit = []
        all_dec_slit = []

        for im, model in enumerate(models):
            wcs = model.meta.wcs
            bbox = wcs.bounding_box
            grid = wcstools.grid_from_bounding_box(bbox)
            ra, dec, lam = np.array(wcs(*grid))
            # Handle vertical (MIRI) or horizontal (NIRSpec) dispersion.  The
            # following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
            spectral_axis = find_dispersion_axis(model)
            spatial_axis = spectral_axis ^ 1

            # Compute the wavelength array, trimming NaNs from the ends.
            # A whole slice is often all-NaN; silence the resulting
            # RuntimeWarning locally without resetting global warning filters.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                wavelength_array = np.nanmedian(lam, axis=spectral_axis)
            wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

            # The spatial sampling of the output WCS is estimated from the
            # reference model only; it is assumed to be the same for all
            # input models.  Steps:
            # 1. find the middle of the spectrum in wavelength
            # 2. pull out the ra and dec at the center of the slit
            # 3. the mean ra,dec along that slice is the tangent point
            # 4. convert ra,dec -> tangent plane projection: x_tan,y_tan
            # 5. linearly fit x_tan, y_tan vs. pixel to get the spatial sampling
            if im == 0:
                all_wavelength = np.append(all_wavelength, wavelength_array)

                # find the center ra and dec for this slit at central wavelength
                lam_center_index = int((bbox[spectral_axis][1] -
                                        bbox[spectral_axis][0]) / 2)
                if spatial_axis == 0:  # MIRI LRS, the WCS x axis is spatial
                    ra_slice = ra[lam_center_index, :]
                    dec_slice = dec[lam_center_index, :]
                else:
                    ra_slice = ra[:, lam_center_index]
                    dec_slice = dec[:, lam_center_index]
                # wrap RA if near zero
                ra_center_pt = np.nanmean(wrap_ra(ra_slice))
                dec_center_pt = np.nanmean(dec_slice)

                # convert ra and dec to tangent projection
                tan = Pix2Sky_TAN()
                native2celestial = RotateNative2Celestial(ra_center_pt, dec_center_pt, 180)
                undist2sky1 = tan | native2celestial
                # Filter out RuntimeWarnings due to computed NaNs in the WCS
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    # at this center of slit find x,y tangent projection - x_tan, y_tan
                    x_tan, y_tan = undist2sky1.inverse(ra, dec)

                # pull out the cross-dispersion profile at the central index
                if spectral_axis == 0:  # NIRSpec-like, the WCS x axis is spectral
                    x_tan_array = x_tan.T[lam_center_index]
                    y_tan_array = y_tan.T[lam_center_index]
                else:  # MIRI LRS, the WCS x axis is spatial
                    x_tan_array = x_tan[lam_center_index]
                    y_tan_array = y_tan[lam_center_index]

                x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
                y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

                # Estimate the spatial sampling.  Pixel coordinates are
                # k / pscale_ratio for k = 0..n-1; an integer range guarantees
                # exactly n samples (float np.arange may yield n + 1).
                fitter = LinearLSQFitter()
                fit_model = Linear1D()
                x_pix = np.arange(x_tan_array.shape[0]) / self.pscale_ratio
                y_pix = np.arange(y_tan_array.shape[0]) / self.pscale_ratio
                pix_to_xtan = fitter(fit_model, x_pix, x_tan_array)
                pix_to_ytan = fitter(fit_model, y_pix, y_tan_array)

            # append all ra and dec values to use later to find min and max
            # ra and dec
            ra_use = ra[~np.isnan(ra)].flatten()
            dec_use = dec[~np.isnan(dec)].flatten()
            all_ra_slit = np.append(all_ra_slit, ra_use)
            all_dec_slit = np.append(all_dec_slit, dec_use)

            # now check wavelength array to see if we need to add to it
            this_minw = np.min(wavelength_array)
            this_maxw = np.max(wavelength_array)
            all_minw = np.min(all_wavelength)
            all_maxw = np.max(all_wavelength)
            if this_minw < all_minw:
                addpts = wavelength_array[wavelength_array < all_minw]
                all_wavelength = np.append(all_wavelength, addpts)
            if this_maxw > all_maxw:
                addpts = wavelength_array[wavelength_array > all_maxw]
                all_wavelength = np.append(all_wavelength, addpts)

        # done looping over set of models
        all_ra = np.hstack(all_ra_slit)
        all_dec = np.hstack(all_dec_slit)
        all_wave = np.hstack(all_wavelength)
        all_wave = all_wave[~np.isnan(all_wave)]
        all_wave = np.sort(all_wave, axis=None)
        # Tabular interpolation model, pixels -> lambda
        wavelength_array = np.unique(all_wave)
        # Check if the data is MIRI LRS FIXED Slit. If it is then
        # the wavelength array needs to be flipped so that the resampled
        # dispersion direction matches the dispersion direction on the detector.
        if self.input_models[0].meta.exposure.type == 'MIR_LRS-FIXEDSLIT':
            wavelength_array = np.flip(wavelength_array, axis=None)

        wave_pix = np.arange(wavelength_array.shape[0]) / self.pscale_ratio
        pix_to_wavelength = Tabular1D(points=wave_pix,
                                      lookup_table=wavelength_array,
                                      bounds_error=False, fill_value=None,
                                      name='pix2wavelength')

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = wave_pix

        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=points,
                                              lookup_table=lookup_table,
                                              bounds_error=False, fill_value=None,
                                              name='wavelength2pix')

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

        # Sometimes the slit is perpendicular to the RA or Dec axis.
        # For example, if the slit is perpendicular to RA, that means
        # the slope of pix_to_xtan will be nearly zero, so make sure
        # mapping.inverse uses pix_to_ytan.inverse.  The auto definition
        # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.

        swap_xy = np.isclose(pix_to_xtan.slope, 0, atol=1e-8)
        if swap_xy:
            # Account for vertical or horizontal dispersion on detector
            mapping.inverse = Mapping((2, 1) if spatial_axis else (1, 2))

        # The final transform
        # redefine the ra, dec center tangent point to include all data

        # check if all_ra crosses 0 degrees - this makes it hard to
        # define the min and max ra correctly
        all_ra = wrap_ra(all_ra)
        ra_min = np.amin(all_ra)
        ra_max = np.amax(all_ra)
        ra_center_final = (ra_max + ra_min) / 2.0

        dec_min = np.amin(all_dec)
        dec_max = np.amax(all_dec)
        dec_center_final = (dec_max + dec_min) / 2.0

        tan = Pix2Sky_TAN()
        if len(models) == 1:  # single model use ra_center_pt to be consistent
            # with how resample was done before
            ra_center_final = ra_center_pt
            dec_center_final = dec_center_pt

        native2celestial = RotateNative2Celestial(ra_center_final, dec_center_final, 180)
        undist2sky = tan | native2celestial
        # find the spatial size of the output - same in x,y
        if swap_xy:
            _, x_tan_all = undist2sky.inverse(all_ra, all_dec)
            pix_to_tan_slope = pix_to_ytan.slope
        else:
            x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
            pix_to_tan_slope = pix_to_xtan.slope

        x_min = np.amin(x_tan_all)
        x_max = np.amax(x_tan_all)
        x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_tan_slope)))
        # Center the spatial fit on the output array
        if swap_xy:
            pix_to_ytan.intercept = -0.5 * (x_size - 1) * pix_to_ytan.slope
        else:
            pix_to_xtan.intercept = -0.5 * (x_size - 1) * pix_to_xtan.slope

        # single model use size of x_tan_array
        # to be consistent with method before
        if len(models) == 1:
            x_size = len(x_tan_array)

        # define the output wcs
        transform = mapping | (pix_to_xtan & pix_to_ytan | undist2sky) & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                                reference_frame=coord.ICRS())
        spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                                unit=(u.micron,), axes_names=('wavelength',))
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, transform),
                    (world, None)]

        output_wcs = WCS(pipeline)

        # compute the output array size in WCS axes order, i.e. (x, y)
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array) / self.pscale_ratio))
        output_array_size[spatial_axis] = int(np.ceil(x_size / self.pscale_ratio))

        # turn the size into a numpy shape in (y, x) order
        output_wcs.array_shape = output_array_size[::-1]
        output_wcs.pixel_shape = output_array_size
        bounding_box = resample_utils.wcs_bbox_from_shape(output_array_size[::-1])
        output_wcs.bounding_box = bounding_box

        return output_wcs