Example #1
0
def test_parameters_compound_models():
    """
    Chain a unitless ``Rotation2D`` with a ``RotateNative2Celestial`` whose
    parameters are quantities.

    No assertions are made here; the test passes as long as building and
    composing the models does not raise.
    """
    projection = Pix2Sky_TAN()  # instantiated only; not part of the chain
    pointing = coord.SkyCoord(ra=5.6, dec=-72, unit=u.deg)
    pole_longitude = 180 * u.deg
    sky_rotation = RotateNative2Celestial(pointing.ra, pointing.dec,
                                          pole_longitude)
    plane_rotation = Rotation2D(23)
    compound = plane_rotation | sky_rotation
def test_compound_return_units():
    """
    Check that a compound model honours the ``return_units`` of its first
    component when passing values on to the second component.
    """
    class PassModel(Model):
        # Two-input, two-output model that hands back the bare ``.value`` of
        # each input and declares degrees for both inputs and outputs.

        inputs = ('x', 'y')
        outputs = ('x', 'y')

        @property
        def input_units(self):
            """Units declared for each input."""
            return dict(x=u.deg, y=u.deg)

        @property
        def return_units(self):
            """Units declared for each output."""
            return dict(x=u.deg, y=u.deg)

        def evaluate(self, x, y):
            return x.value, y.value

    chained = Pix2Sky_TAN() | PassModel()
    result = chained(0 * u.deg, 0 * u.deg)
    expected = (0, 90) * u.deg

    assert_quantity_allclose(result, expected)
Example #3
0
def compute_spec_transform(fiducial, refwcs):
    """
    Build a simple transform around a fiducial point of a spatial-spectral
    wcs.

    Parameters
    ----------
    fiducial : sequence
        ``fiducial[0][0]`` and ``fiducial[0][1]`` are passed as the first two
        arguments of ``RotateNative2Celestial``.
    refwcs : callable
        Must expose ``wcsinfo.cdelt1``, ``wcsinfo.cdelt2``, ``wcsinfo.cdelt3``
        and ``wcsinfo.roll_ref``, and return three arrays when called as
        ``refwcs(x, y)``; the third one is used as the wavelength grid.

    Returns
    -------
    transform
        Compound model whose outputs are named ``('ra', 'dec', 'lamda')``.
    """
    wcsinfo = refwcs.wcsinfo
    # The two spatial increments are divided by 3600 (presumably
    # arcsec -> degrees; confirm against the wcsinfo schema).  The spectral
    # increment is used unchanged.
    spatial_step_1 = wcsinfo.cdelt1 / 3600.
    spatial_step_2 = wcsinfo.cdelt2 / 3600.
    spectral_step = wcsinfo.cdelt3
    roll_angle = wcsinfo.roll_ref

    y, x = grid_from_spec_domain(refwcs)
    # Only the third output is needed; it supplies the spectral zero point.
    _, _, lam = refwcs(x, y)
    lam_zero = np.nanmin(lam)

    # Spatial branch: null offset, roll rotation, scaling, TAN deprojection
    # and rotation to the fiducial point.
    null_offset = Shift(0.) & Shift(0.)
    roll = Rotation2D(roll_angle)
    spatial_scale = Scale(spatial_step_1) & Scale(spatial_step_2)
    deproject = Pix2Sky_TAN()
    to_sky = RotateNative2Celestial(fiducial[0][0], fiducial[0][1], 180.0)
    spatial = null_offset | roll | spatial_scale | deproject | to_sky

    # Spectral branch: linear in pixel, starting at the smallest wavelength.
    spectral = Scale(spectral_step) | Shift(lam_zero)

    # Input 1 feeds both spatial axes, input 0 feeds the spectral axis.
    axis_map = Mapping((1, 1, 0))
    axis_map.inverse = Mapping((2, 1))

    transform = axis_map | spatial & spectral
    transform.outputs = ('ra', 'dec', 'lamda')
    return transform
Example #4
0
def test_compound_return_units():
    """
    Check that a compound model honours the ``return_units`` of its first
    component when passing values on to the second component.
    """
    class PassModel(Model):
        # Two-input, two-output model that hands back the bare ``.value`` of
        # each input and declares degrees for both inputs and outputs.

        n_inputs = 2
        n_outputs = 2

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        @property
        def input_units(self):
            """Units declared for each input."""
            return dict(x0=u.deg, x1=u.deg)

        @property
        def return_units(self):
            """Units declared for each output."""
            return dict(x0=u.deg, x1=u.deg)

        def evaluate(self, x, y):
            return x.value, y.value

    chained = Pix2Sky_TAN() | PassModel()
    result = chained(0 * u.deg, 0 * u.deg)
    expected = (0, 90) * u.deg

    assert_quantity_allclose(result, expected)
Example #5
0
def _make_reference_gwcs_wcs(fits_hdr):
    """
    Build a three-frame gWCS (detector -> v2v3 -> sky) from a FITS header
    file shipped as package data.

    Parameters
    ----------
    fits_hdr : str
        Name handed to ``get_pkg_data_filename`` to locate the header file.

    Returns
    -------
    gwcs.WCS
        WCS with ``crpix``, ``crval`` and ``bounding_box`` set from the FITS
        WCS read from the header.
    """
    header = fits.Header.fromfile(get_pkg_data_filename(fits_hdr))
    fits_wcs = fitswcs.WCS(header)

    # Arcsec <-> degree conversions, applied to both axes at once.
    arcsec_to_deg = Scale(1.0 / 3600.0, name='arcsec_to_deg_1D')
    unit_conv = arcsec_to_deg & arcsec_to_deg
    unit_conv.name = 'arcsec_to_deg_2D'

    deg_to_arcsec = Scale(3600.0, name='deg_to_arcsec_1D')
    unit_conv_inv = deg_to_arcsec & deg_to_arcsec
    unit_conv_inv.name = 'deg_to_arcsec_2D'

    cart_to_sph = CartesianToSpherical(name='c2s', wrap_lon_at=180)
    sph_to_cart = SphericalToCartesian(name='s2c', wrap_lon_at=180)

    # (x, y, z) -> (y/x, z/x): divide every component by x, then drop the
    # leading component.
    divide_by_x = (Mapping((0, 1, 2), name='xyz')
                   / Mapping((0, 0, 0), n_inputs=3, name='xxx'))
    c2tan = divide_by_x | Mapping((1, 2), name='xtyt')
    c2tan.name = 'Cartesian 3D to TAN'

    # (xt, yt) -> (1, xt, yt)
    one_and_identity = Const1D(1, name='one') & Identity(2, name='I(2D)')
    tan2c = Mapping((0, 0, 1), n_inputs=2, name='xtyt2xyz') | one_and_identity
    tan2c.name = 'TAN to cartesian 3D'

    # The two mappings are declared as each other's inverse.
    tan2c.inverse = c2tan
    c2tan.inverse = tan2c

    linear = AffineTransformation2D(matrix=fits_wcs.wcs.cd)

    shift_x = Shift(-fits_wcs.wcs.crpix[0])
    shift_y = Shift(-fits_wcs.wcs.crpix[1])

    pixel_factor = 5e-6
    scale = Scale(pixel_factor) & Scale(pixel_factor)

    # detector -> v2v3
    det2tan = (shift_x & shift_y) | scale | tan2c | cart_to_sph | unit_conv_inv

    # v2v3 -> sky
    taninv = sph_to_cart | c2tan
    deproject = Pix2Sky_TAN()
    to_celestial = RotateNative2Celestial(fits_wcs.wcs.crval[0],
                                          fits_wcs.wcs.crval[1], 180)
    wcslin = (unit_conv | taninv | scale.inverse | linear | deproject
              | to_celestial)

    sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS())
    detector_frame = cf.Frame2D(name='detector')
    v2v3_frame = cf.Frame2D(name="v2v3",
                            unit=(u.arcsec, u.arcsec),
                            axes_names=('x', 'y'),
                            axes_order=(0, 1))
    steps = [
        (detector_frame, det2tan),
        (v2v3_frame, wcslin),
        (sky_frame, None),
    ]

    result = gwcs.WCS(input_frame=detector_frame,
                      output_frame=sky_frame,
                      forward_transform=steps)
    result.crpix = fits_wcs.wcs.crpix
    result.crval = fits_wcs.wcs.crval
    result.bounding_box = ((-0.5, fits_wcs.pixel_shape[0] - 0.5),
                           (-0.5, fits_wcs.pixel_shape[1] - 0.5))

    return result
Example #6
0
def map_to_transform(smap):
    """
    Build the pixel-to-sky compound model described by ``smap``.

    The chain is: shift by the negated reference pixel, multiply by the
    per-axis scale, apply the rotation matrix as an affine transform,
    TAN-deproject, and rotate native to celestial coordinates about the
    reference coordinate.  ``rename("spatial")`` is called on the compound
    model before it is returned.

    Parameters
    ----------
    smap : object
        Must expose ``reference_pixel``, ``reference_coordinate`` (with ``Tx``
        and ``Ty``), ``scale`` and ``rotation_matrix``.  Presumably a sunpy
        map; confirm against callers.

    Returns
    -------
    The compound model.
    """
    # Shift the reference pixel from FITS (1) to Python (0) indexing.
    ref_pix_x, ref_pix_y = u.Quantity(smap.reference_pixel) - 1 * u.pixel
    ref_lon, ref_lat = (smap.reference_coordinate.Tx,
                        smap.reference_coordinate.Ty)
    step_x, step_y = smap.scale
    rotation = smap.rotation_matrix * u.arcsec

    # FITS WCS uses the negative of the reference pixel as the shift.
    to_reference = Shift(-ref_pix_x) & Shift(-ref_pix_y)
    to_angular = Multiply(step_x) & Multiply(step_y)
    rotate = AffineTransformation2D(rotation, translation=(0, 0) * u.arcsec)
    # Handle the projection.
    deproject = Pix2Sky_TAN()
    # Rotate from native spherical to celestial spherical coordinates.
    to_celestial = RotateNative2Celestial(ref_lon, ref_lat, 180 * u.deg)

    # Combine the whole pipeline into one compound model.
    spatial = to_reference | to_angular | rotate | deproject | to_celestial
    spatial.rename("spatial")

    return spatial
Example #7
0
def spatial_model_from_quantity(crpix1,
                                crpix2,
                                cdelt1,
                                cdelt2,
                                pc,
                                crval1,
                                crval2,
                                projection='TAN'):
    """
    Return the spatial transform model for quantity representations of a
    HPLx FITS WCS.

    The ordering of ctype1 and ctype2 should be LON, LAT.

    Parameters
    ----------
    crpix1, crpix2
        Reference pixel per axis; each is negated and used as a shift.
    cdelt1, cdelt2
        Per-axis multiplicative scale.
    pc
        Matrix handed to ``AffineTransformation2D``.
    crval1, crval2
        Passed to ``RotateNative2Celestial`` as its first two arguments.
    projection : str, optional
        Key into the table of supported projections; only ``'TAN'`` is
        registered, any other key raises ``KeyError``.

    Returns
    -------
    Compound model: shift | scale | affine | projection | sky rotation.
    """

    # TODO: Find this from somewhere else or extend it or something
    supported = {'TAN': Pix2Sky_TAN()}

    to_reference = Shift(-crpix1) & Shift(-crpix2)
    to_angular = Multiply(cdelt1) & Multiply(cdelt2)
    rotate = AffineTransformation2D(pc, translation=(0, 0) * u.arcsec)
    deproject = supported[projection]
    to_celestial = RotateNative2Celestial(crval1, crval2, 180 * u.deg)

    return to_reference | to_angular | rotate | deproject | to_celestial
Example #8
0
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels

        Parameters
        ----------
        refwcs : WCS object, optional
            Reference WCS.  When ``None``, the WCS of the first input model
            (``self.input_models[0].meta.wcs``) is used.

        Notes
        -----
        Nothing is returned.  The results are stored on the instance as
        ``output_spatial_scale``, ``output_spectral_scale`` and
        ``output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        x, y = wcstools.grid_from_bounding_box(refwcs.bounding_box,
                                               step=(1, 1),
                                               center=True)
        ra, dec, lam = refwcs(x.flatten(), y.flatten())
        # TODO: once astropy.modeling._Tabular is fixed, take out the
        # flatten() and reshape() code above and below
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        # With no Rotation2D component there is nothing to undo, so the plain
        # forward transform serves for the slit-center lookup further down.
        # (Previously the name was left unbound in that case.)
        refwcs_minus_rot = refwcs.forward_transform
        if rotation:
            rot_slit = functools.reduce(lambda a, b: a | b, rotation)
            unrotate = rot_slit.inverse
            refwcs_minus_rot = refwcs.forward_transform | \
                unrotate & Identity(1)
            # Correct for this rotation in the wcs
            ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
            ra = ra.reshape(x.shape)
            dec = dec.reshape(x.shape)
            lam = lam.reshape(x.shape)

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from bounding_box
        dx0 = refwcs.bounding_box[0][0]
        dx1 = refwcs.bounding_box[0][1]
        dy0 = refwcs.bounding_box[1][0]
        dy1 = refwcs.bounding_box[1][1]
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                       dy0 + dispersion_center_pix)
        slit_center_sky = SkyCoord(ra=slit_center[0],
                                   dec=slit_center[1],
                                   unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s' % slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra,
                                        slit_center_sky.dec, lon_pole)
        min_lam = np.nanmin(lam)
        # Duplicate the first (spatial) coordinate so it feeds both axes of
        # the sky branch; the second coordinate carries the wavelength.
        mapping = Mapping((0, 0, 1))

        transform = (Shift(-slit_center_pix) & Identity(1)
                     | Scale(spatial_scale) & Scale(spectral_scale)
                     | Identity(1) & Shift(min_lam)
                     | mapping
                     | (rot | tan | skyrot) & Identity(1))

        # Input names must be strings; the coordinate arrays ``x`` and ``y``
        # were previously assigned here by mistake.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)

        bounding_box = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(bounding_box_grid[axis])
            axis_max = np.nanmax(bounding_box_grid[axis])
            bounding_box.append((axis_min, axis_max))
        wnew.bounding_box = tuple(bounding_box)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Example #9
0
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels

        Parameters
        ----------
        refwcs : WCS object, optional
            Reference WCS.  When ``None``, the WCS of the first input model
            (``self.input_models[0].meta.wcs``) is used.

        Notes
        -----
        Nothing is returned.  The results are stored on the instance as
        ``output_spatial_scale``, ``output_spectral_scale`` and
        ``output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within bounding box
        bb = refwcs.bounding_box
        x, y = wcstools.grid_from_bounding_box(bb,
                                               step=(1, 1),
                                               center=True)
        ra, dec, lam = refwcs(x, y)
        x_center = int((bb[0][1] - bb[0][0]) / 2)
        y_center = int((bb[1][1] - bb[1][0]) / 2)
        log.debug("Center of bounding box: {}  {}".format(x_center, y_center))

        # Compute slit angular size, slit center sky coords
        # For every row, find the x position at which the wavelength equals
        # the wavelength at the center of the bounding box, interpolating
        # over a window of 2*sz - 1 columns around the central column.
        xpos = []
        sz = 3
        for row in lam:
            if np.isnan(row[x_center]):
                xpos.append(np.nan)
            else:
                f = interpolate.interp1d(row[x_center - sz + 1:x_center + sz],
                                         x[y_center,
                                           x_center - sz + 1:x_center + sz],
                                         bounds_error=False,
                                         fill_value='extrapolate')
                xpos.append(f(lam[y_center, x_center]))
        x_arg = np.array(xpos)[~np.isnan(lam[:, x_center])]
        y_arg = y[~np.isnan(lam[:, x_center]), x_center]
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        # Pixel index along the slit, counted from the far end.
        pix_num = np.flipud(np.arange(len(slit_ra)))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                                   dec=interpol_dec(slit_center_pix),
                                   unit=u.deg)
        log.debug('Slit center: {0}'.format(slit_center_sky))
        log.debug('Fiducial: {0}'.format(
            resample_utils.compute_spec_fiducial([refwcs])))
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales
        xposn = np.array(xpos)[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        # Length of the slit in pixels: hypotenuse of the x drift and the
        # number of row steps.
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        slit_rot_angle = slit_rot_angle.value
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)

        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale.value) & Scale(spatial_scale.value)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra.value,
                                        slit_center_sky.dec.value,
                                        lon_pole.value)

        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # Input 1 feeds both spatial axes, input 0 feeds the spectral axis.
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)
        bounding_box = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(bounding_box_grid[axis])
            axis_max = np.nanmax(bounding_box_grid[axis])
            bounding_box.append((axis_min, axis_max))
        wnew.bounding_box = tuple(bounding_box)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Example #10
0
    def build_interpolated_output_wcs(self, refmodel=None):
        """
        Create a spatial/spectral WCS output frame using all the input models

        Creates output frame by linearly fitting RA, Dec along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.

        Parameters
        ----------
        refmodel : `~jwst.datamodels.DataModel`
            Accepted for interface compatibility; this implementation does
            not read it.  The first model in ``self.input_models`` always
            sets the spatial sampling.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            A gwcs WCS object defining the output frame WCS

        Notes
        -----
        Also sets ``self.data_size`` to the output array shape in (y, x)
        order.
        """

        # for each input model convert slit x,y to ra,dec,lam
        # use first input model to set spatial scale
        # use center of appended ra and dec arrays to set up
        # center of final ra,dec
        # append all ra,dec, wavelength array for each slit
        # use first model to initialize wavelength array
        # append wavelengths that fall outside the endpoint of
        # of wavelength array when looping over additional data

        all_wavelength = []
        all_ra_slit = []
        all_dec_slit = []

        for im, model in enumerate(self.input_models):
            wcs = model.meta.wcs
            bb = wcs.bounding_box
            grid = wcstools.grid_from_bounding_box(bb)
            ra, dec, lam = np.array(wcs(*grid))
            spectral_axis = find_dispersion_axis(model)
            spatial_axis = spectral_axis ^ 1

            # Compute the wavelength array, trimming NaNs from the ends
            # In many cases, a whole slice is NaNs, so ignore those warnings.
            # catch_warnings restores the caller's filters afterwards instead
            # of wiping them with resetwarnings().
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                wavelength_array = np.nanmedian(lam, axis=spectral_axis)
            wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

            # We need to estimate the spatial sampling to use for the output WCS.
            # It is assumed the spatial sampling is the same for all the input
            # models. So we can use the first input model to set the spatial
            # sampling.

            # Steps to do this for first input model:
            # 1. find the middle of the spectrum in wavelength
            # 2. Pull out the ra and dec at the center of the slit.
            # 3. Find the mean ra,dec and the center of the slit this will
            #    represent the tangent point
            # 4. Convert ra,dec -> tangent plane projection: x_tan,y_tan
            # 5. using x_tan, y_tan perform a linear fit to find spatial sampling
            # first input model initializes wavelength array and defines
            # the spatial scale of the output wcs
            if im == 0:
                all_wavelength.extend(wavelength_array)

                lam_center_index = int(
                    (bb[spectral_axis][1] - bb[spectral_axis][0]) / 2)
                if spatial_axis == 0:
                    ra_center = ra[lam_center_index, :]
                    dec_center = dec[lam_center_index, :]
                else:
                    ra_center = ra[:, lam_center_index]
                    dec_center = dec[:, lam_center_index]
                # find the ra and dec for this slit using center of slit
                ra_center_pt = np.nanmean(ra_center)
                dec_center_pt = np.nanmean(dec_center)

                if resample_utils.is_sky_like(model.meta.wcs.output_frame):
                    # convert ra and dec to tangent projection
                    tan = Pix2Sky_TAN()
                    native2celestial = RotateNative2Celestial(
                        ra_center_pt, dec_center_pt, 180)
                    undist2sky1 = tan | native2celestial
                    # Filter out RuntimeWarnings due to computed NaNs in the
                    # WCS while finding the tangent projection (x_tan, y_tan)
                    # about this slit center.
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        x_tan, y_tan = undist2sky1.inverse(ra, dec)
                else:
                    # for non sky-like output frames, no need to do tangent plane projections
                    # but we still use the same variables
                    x_tan, y_tan = ra, dec

                # pull out data from center
                if spectral_axis == 0:
                    x_tan_array = x_tan.T[lam_center_index]
                    y_tan_array = y_tan.T[lam_center_index]
                else:
                    x_tan_array = x_tan[lam_center_index]
                    y_tan_array = y_tan[lam_center_index]

                x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
                y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

                # estimate the spatial sampling
                fitter = LinearLSQFitter()
                fit_model = Linear1D()
                xstop = x_tan_array.shape[0] / self.pscale_ratio
                xstep = 1 / self.pscale_ratio
                ystop = y_tan_array.shape[0] / self.pscale_ratio
                ystep = 1 / self.pscale_ratio
                pix_to_xtan = fitter(fit_model, np.arange(0, xstop, xstep),
                                     x_tan_array)
                pix_to_ytan = fitter(fit_model, np.arange(0, ystop, ystep),
                                     y_tan_array)

            # append all ra and dec values to use later to find min and max
            # ra and dec
            ra_use = ra.flatten()
            ra_use = ra_use[~np.isnan(ra_use)]
            dec_use = dec.flatten()
            dec_use = dec_use[~np.isnan(dec_use)]
            all_ra_slit.append(ra_use)
            all_dec_slit.append(dec_use)

            # now check wavelength array to see if we need to add to it
            this_minw = np.min(wavelength_array)
            this_maxw = np.max(wavelength_array)
            all_minw = np.min(all_wavelength)
            all_maxw = np.max(all_wavelength)

            if this_minw < all_minw:
                all_wavelength.extend(
                    wavelength_array[wavelength_array < all_minw])
            if this_maxw > all_maxw:
                all_wavelength.extend(
                    wavelength_array[wavelength_array > all_maxw])

        # done looping over set of models
        # NOTE: ``model``, ``spectral_axis`` and ``spatial_axis`` used below
        # hold the values from the last loop iteration; the fits and the
        # center point come from the first model.

        all_ra = np.hstack(all_ra_slit)
        all_dec = np.hstack(all_dec_slit)
        all_wave = np.hstack(all_wavelength)
        all_wave = all_wave[~np.isnan(all_wave)]
        all_wave = np.sort(all_wave, axis=None)
        # Tabular interpolation model, pixels -> lambda
        wavelength_array = np.unique(all_wave)
        # Check if the data is MIRI LRS FIXED Slit. If it is then
        # the wavelength array needs to be flipped so that the resampled
        # dispersion direction matches the dispersion direction on the detector.
        if self.input_models[0].meta.exposure.type == 'MIR_LRS-FIXEDSLIT':
            wavelength_array = np.flip(wavelength_array, axis=None)

        step = 1 / self.pscale_ratio
        stop = wavelength_array.shape[0] / self.pscale_ratio
        points = np.arange(0, stop, step)
        pix_to_wavelength = Tabular1D(points=points,
                                      lookup_table=wavelength_array,
                                      bounds_error=False,
                                      fill_value=None,
                                      name='pix2wavelength')

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = np.arange(0, stop, step)

        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=points,
                                              lookup_table=lookup_table,
                                              bounds_error=False,
                                              fill_value=None,
                                              name='wavelength2pix')

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

        # Sometimes the slit is perpendicular to the RA or Dec axis.
        # For example, if the slit is perpendicular to RA, that means
        # the slope of pix_to_xtan will be nearly zero, so make sure
        # mapping.inverse uses pix_to_ytan.inverse.  The auto definition
        # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.

        if np.isclose(pix_to_ytan.slope, 0, atol=1e-8):
            mapping_tuple = (0, 1)
            # Account for vertical or horizontal dispersion on detector
            if spatial_axis:
                mapping.inverse = Mapping(mapping_tuple[::-1])
            else:
                mapping.inverse = Mapping(mapping_tuple)

        # The final transform
        # redefine the ra, dec center tangent point to include all data

        # check if all_ra crosses 0 degrees - this makes it hard to
        # define the min and max ra correctly
        all_ra = wrap_ra(all_ra)
        ra_min = np.amin(all_ra)
        ra_max = np.amax(all_ra)

        ra_center_final = (ra_max + ra_min) / 2.0
        dec_min = np.amin(all_dec)
        dec_max = np.amax(all_dec)
        dec_center_final = (dec_max + dec_min) / 2.0
        tan = Pix2Sky_TAN()
        if len(self.input_models) == 1:
            # single model use ra_center_pt to be consistent
            # with how resample was done before
            ra_center_final = ra_center_pt
            dec_center_final = dec_center_pt

        if resample_utils.is_sky_like(model.meta.wcs.output_frame):
            native2celestial = RotateNative2Celestial(ra_center_final,
                                                      dec_center_final, 180)
            undist2sky = tan | native2celestial
            # find the spatial size of the output - same in x,y
            x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
        else:
            x_tan_all, _ = all_ra, all_dec
        x_min = np.amin(x_tan_all)
        x_max = np.amax(x_tan_all)
        x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_xtan.slope)))

        # single model use size of x_tan_array
        # to be consistent with method before
        if len(self.input_models) == 1:
            x_size = len(x_tan_array)

        # define the output wcs
        if resample_utils.is_sky_like(model.meta.wcs.output_frame):
            transform = mapping | (pix_to_xtan & pix_to_ytan
                                   | undist2sky) & pix_to_wavelength
        else:
            transform = mapping | (pix_to_xtan
                                   & pix_to_ytan) & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        if resample_utils.is_sky_like(model.meta.wcs.output_frame):
            sky = cf.CelestialFrame(name='sky',
                                    axes_order=(0, 1),
                                    reference_frame=coord.ICRS())
        else:
            sky = cf.Frame2D(
                name=f'resampled_{model.meta.wcs.output_frame.name}',
                axes_order=(0, 1))
        spec = cf.SpectralFrame(name='spectral',
                                axes_order=(2, ),
                                unit=(u.micron, ),
                                axes_names=('wavelength', ))
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, transform), (world, None)]

        output_wcs = WCS(pipeline)

        # compute the output array size in WCS axes order, i.e. (x, y)
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = int(
            np.ceil(len(wavelength_array) / self.pscale_ratio))
        output_array_size[spatial_axis] = int(
            np.ceil(x_size / self.pscale_ratio))
        # turn the size into a numpy shape in (y, x) order
        self.data_size = tuple(output_array_size[::-1])
        bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
        output_wcs.bounding_box = bounding_box

        return output_wcs
Example #11
0
def _make_gwcs_wcs(fits_hdr):
    """
    Build a three-frame gWCS (detector -> v2v3 -> sky) with SIP distortion
    from a FITS header file shipped as package data.

    Parameters
    ----------
    fits_hdr : str
        Name handed to ``get_pkg_data_filename`` to locate the header file.
        The header must contain ``A_ORDER`` and ``B_ORDER``.

    Returns
    -------
    gwcs.WCS
        WCS with ``crpix``, ``crval`` and ``bounding_box`` set from the FITS
        WCS read from the header.

    Raises
    ------
    AssertionError
        If the built WCS disagrees with ``all_pix2world`` of the FITS WCS
        (absolute tolerance 1e-11) at any of 100 random pixel positions.
    """
    hdr = fits.Header.fromfile(get_pkg_data_filename(fits_hdr))
    fw = fitswcs.WCS(hdr)

    def _sip_coeffs(prefix):
        # Collect the ``<prefix>_i_j`` cards that are present (i + j <= order)
        # and key them the way Polynomial2D expects: ``ci_j``.
        order = hdr['{}_ORDER'.format(prefix)]
        coeffs = {}
        for i in range(order + 1):
            for j in range(order + 1 - i):
                key = '{}_{:d}_{:d}'.format(prefix, i, j)
                if key in hdr:
                    coeffs['c{:d}_{:d}'.format(i, j)] = hdr[key]
        return order, coeffs

    a_order, cx = _sip_coeffs('A')
    b_order, cy = _sip_coeffs('B')

    # Shift to the reference pixel, evaluate both polynomials on (x, y) with
    # an added unit linear term, then shift back.
    sip_distortion = ((Shift(-fw.wcs.crpix[0]) & Shift(-fw.wcs.crpix[1]))
                      | Mapping((0, 1, 0, 1))
                      | (polynomial.Polynomial2D(a_order, **cx, c1_0=1)
                         & polynomial.Polynomial2D(b_order, **cy, c0_1=1))
                      | (Shift(fw.wcs.crpix[0]) & Shift(fw.wcs.crpix[1])))

    unit_conv = Scale(1.0 / 3600.0, name='arcsec_to_deg_1D')
    unit_conv = unit_conv & unit_conv
    unit_conv.name = 'arcsec_to_deg_2D'

    unit_conv_inv = Scale(3600.0, name='deg_to_arcsec_1D')
    unit_conv_inv = unit_conv_inv & unit_conv_inv
    unit_conv_inv.name = 'deg_to_arcsec_2D'

    c2s = CartesianToSpherical(name='c2s', wrap_lon_at=180)
    s2c = SphericalToCartesian(name='s2c', wrap_lon_at=180)
    # (x, y, z) -> (y/x, z/x)
    c2tan = ((Mapping((0, 1, 2), name='xyz') / Mapping(
        (0, 0, 0), n_inputs=3, name='xxx')) | Mapping((1, 2), name='xtyt'))
    c2tan.name = 'Cartesian 3D to TAN'

    # (xt, yt) -> (1, xt, yt)
    tan2c = (Mapping((0, 0, 1), n_inputs=2, name='xtyt2xyz') |
             (Const1D(1, name='one') & Identity(2, name='I(2D)')))
    tan2c.name = 'TAN to cartesian 3D'

    # The two mappings are declared as each other's inverse.
    tan2c.inverse = c2tan
    c2tan.inverse = tan2c

    aff = AffineTransformation2D(matrix=fw.wcs.cd)

    offx = Shift(-fw.wcs.crpix[0])
    offy = Shift(-fw.wcs.crpix[1])

    s = 5e-6
    scale = Scale(s) & Scale(s)

    # detector -> v2v3: distortion followed by the undistorted chain
    sip_distortion |= (offx & offy) | scale | tan2c | c2s | unit_conv_inv

    # v2v3 -> sky
    taninv = s2c | c2tan
    tan = Pix2Sky_TAN()
    n2c = RotateNative2Celestial(fw.wcs.crval[0], fw.wcs.crval[1], 180)
    wcslin = unit_conv | taninv | scale.inverse | aff | tan | n2c

    sky_frm = cf.CelestialFrame(reference_frame=coord.ICRS())
    det_frm = cf.Frame2D(name='detector')
    v2v3_frm = cf.Frame2D(name="v2v3",
                          unit=(u.arcsec, u.arcsec),
                          axes_names=('x', 'y'),
                          axes_order=(0, 1))
    pipeline = [(det_frm, sip_distortion), (v2v3_frm, wcslin), (sky_frm, None)]

    gw = gwcs.WCS(input_frame=det_frm,
                  output_frame=sky_frm,
                  forward_transform=pipeline)
    gw.crpix = fw.wcs.crpix
    gw.crval = fw.wcs.crval
    gw.bounding_box = ((-0.5, fw.pixel_shape[0] - 0.5),
                       (-0.5, fw.pixel_shape[1] - 0.5))

    # sanity check: compare against the FITS WCS (origin 1) at random pixels
    for _ in range(100):
        x = np.random.randint(1, fw.pixel_shape[0])
        y = np.random.randint(1, fw.pixel_shape[1])
        assert np.allclose(gw(x, y),
                           fw.all_pix2world(x, y, 1),
                           rtol=0,
                           atol=1e-11)

    return gw
Example #12
0
def test_uses_quantity_no_param():
    """Check ``uses_quantity`` on a ``Mapping | Pix2Sky_TAN`` compound model."""
    passthrough = Mapping((0, 1))
    projection = Pix2Sky_TAN()
    chained = passthrough | projection

    assert chained.uses_quantity
Example #13
0
    def build_interpolated_output_wcs(self, refmodel=None):
        """
        Create a spatial/spectral WCS output frame

        Creates output frame by linearly fitting RA, Dec along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.

        Parameters
        ----------
        refmodel : `~jwst.datamodels.DataModel`
            The reference input image from which the fiducial WCS is created.
            If not specified, the first image in self.input_models is used.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            A gwcs WCS object defining the output frame WCS

        Notes
        -----
        As a side effect, ``self.data_size`` is set to the output array shape
        in numpy ``(y, x)`` order, and the returned WCS gets a bounding box
        derived from that shape.
        """

        if refmodel is None:
            refmodel = self.input_models[0]

        # Evaluate the reference WCS over its whole bounding box.
        refwcs = refmodel.meta.wcs
        bb = refwcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bb)
        ra, dec, lam = np.array(refwcs(*grid))

        # Tangent projection centred on the (NaN-ignoring) mean sky position.
        lon = np.nanmean(ra)
        lat = np.nanmean(dec)
        tan = Pix2Sky_TAN()
        native2celestial = RotateNative2Celestial(lon, lat, 180)
        undist2sky = tan | native2celestial

        # Filter out RuntimeWarnings due to computed NaNs in the WCS.
        # catch_warnings() restores the caller's warning filters on exit
        # (even if the inverse raises), instead of wiping every filter in
        # the process the way warnings.resetwarnings() would.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x_tan, y_tan = undist2sky.inverse(ra, dec)

        spectral_axis = find_dispersion_axis(refmodel)
        # The two axes are 0 and 1, so XOR with 1 selects the other one.
        spatial_axis = spectral_axis ^ 1

        # Compute the wavelength array, trimming NaNs from the ends
        wavelength_array = np.nanmedian(lam, axis=spectral_axis)
        wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

        # Compute RA and Dec up the slit (spatial direction) at the center
        # of the dispersion.  Use spectral_axis to determine slicing dimension
        lam_center_index = int(
            (bb[spectral_axis][1] - bb[spectral_axis][0]) / 2)
        if not spectral_axis:
            x_tan_array = x_tan.T[lam_center_index]
            y_tan_array = y_tan.T[lam_center_index]
        else:
            x_tan_array = x_tan[lam_center_index]
            y_tan_array = y_tan[lam_center_index]
        x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
        y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

        # Linear fits of the tangent-plane coordinates against pixel index
        # along the slit.
        fitter = LinearLSQFitter()
        fit_model = Linear1D()
        pix_to_ra = fitter(fit_model, np.arange(x_tan_array.shape[0]),
                           x_tan_array)
        pix_to_dec = fitter(fit_model, np.arange(y_tan_array.shape[0]),
                            y_tan_array)

        # Tabular interpolation model, pixels -> lambda
        pix_to_wavelength = Tabular1D(lookup_table=wavelength_array,
                                      bounds_error=False,
                                      fill_value=None,
                                      name='pix2wavelength')

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = np.arange(wavelength_array.shape[0])
        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=points,
                                              lookup_table=lookup_table,
                                              bounds_error=False,
                                              fill_value=None,
                                              name='wavelength2pix')

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

        # Sometimes the slit is perpendicular to the RA or Dec axis.
        # For example, if the slit is perpendicular to RA, that means
        # the slope of pix_to_ra will be nearly zero, so make sure
        # mapping.inverse uses pix_to_dec.inverse.  The auto definition
        # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.
        # The check below covers the complementary case: when the fitted
        # pix_to_dec slope is ~0, the default inverse is overridden.
        if np.isclose(pix_to_dec.slope, 0, atol=1e-8):
            mapping_tuple = (0, 1)
            # Account for vertical or horizontal dispersion on detector
            if spatial_axis:
                mapping.inverse = Mapping(mapping_tuple[::-1])
            else:
                mapping.inverse = Mapping(mapping_tuple)

        # The final transform
        transform = mapping | (pix_to_ra & pix_to_dec
                               | undist2sky) & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        sky = cf.CelestialFrame(name='sky',
                                axes_order=(0, 1),
                                reference_frame=coord.ICRS())
        spec = cf.SpectralFrame(name='spectral',
                                axes_order=(2, ),
                                unit=(u.micron, ),
                                axes_names=('wavelength', ))
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, transform), (world, None)]

        output_wcs = WCS(pipeline)

        # compute the output array size in WCS axes order, i.e. (x, y)
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = len(wavelength_array)
        output_array_size[spatial_axis] = len(x_tan_array)

        # turn the size into a numpy shape in (y, x) order
        self.data_size = tuple(output_array_size[::-1])
        bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
        output_wcs.bounding_box = bounding_box

        return output_wcs