예제 #1
0
def wcs_from_spec_footprints(wcslist, refwcs=None, transform=None, domain=None):
    """
    Create a WCS from a list of spatial/spectral WCS.

    Build-7 workaround: although a list is accepted, only the footprint of
    the first WCS in ``wcslist`` is used to build the output domain.

    Parameters
    ----------
    wcslist : iterable of `~gwcs.wcs.WCS`
        Input WCS objects.
    refwcs : `~gwcs.wcs.WCS`, optional
        Reference WCS supplying the input and output frames.  Defaults to
        ``wcslist[0]``.
    transform : optional
        Accepted for API compatibility only.  It is ignored: the output
        transform is always recomputed from the fiducial point and ``refwcs``.
    domain : optional
        Forwarded to ``compute_spec_fiducial``.

    Returns
    -------
    wnew : `~gwcs.wcs.WCS`
        Output WCS whose forward transform is preceded by a per-axis ``Shift``
        by the minimum pixel coordinate of the first footprint, and whose
        ``domain`` spans ``[min, max + 1)`` of that footprint on each axis.

    Raises
    ------
    ValueError
        If ``wcslist`` is not iterable.
    TypeError
        If an item of ``wcslist`` or ``refwcs`` is not a `~gwcs.wcs.WCS`.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")
    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError("All items in 'wcslist' must have instance of gwcs.WCS")
    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs.  For
    # now, we just do it for one, using the api for a list of wcs.
    # Compute a fiducial point for the output frame at center of input data
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    # Create transform for output frame (any caller-supplied ``transform`` is
    # intentionally replaced here).
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the domain in the output frame wcs object by running the input wcs
    # footprints through the backward transform of the output wcs
    sky = [spec_footprint(w) for w in wcslist]
    domain_grid = [wnew.backward_transform(*f) for f in sky]

    # Per-axis extent of the first footprint, computed once and reused for
    # both the pixel offsets and the domain bounds.
    det = domain_grid[0]
    input_frame = refwcs.input_frame
    axis_mins = [np.nanmin(det[axis]) for axis in input_frame.axes_order]
    axis_maxs = [np.nanmax(det[axis]) + 1 for axis in input_frame.axes_order]

    # Only the first two axes are shifted, as the transform is built for a
    # 2-D input frame.
    transform = Shift(axis_mins[0]) & Shift(axis_mins[1]) | transform
    wnew = WCS(output_frame=output_frame, input_frame=input_frame,
               forward_transform=transform)

    # NOTE(review): these bounds are in the unshifted pixel coordinates used
    # to compute the offsets above; confirm that is intended for the shifted
    # transform.
    wnew.domain = [
        {'lower': lower, 'upper': upper,
         'includes_lower': True, 'includes_upper': False}
        for lower, upper in zip(axis_mins, axis_maxs)
    ]
    return wnew
예제 #2
0
def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray]) -> float:
    """Return the geometric-mean pixel scale of ``wcs`` at ``fiducial``.

    The fiducial sky position is inverted to a reference pixel.  The pixels
    one step along x and one step along y are then projected back to the sky,
    and the angular separations of those two neighbours from the reference
    point are combined as ``sqrt(xscale * yscale)``.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        Reference WCS object from which to compute a scaling factor.

    fiducial : tuple
        Sky position ``(RA, DEC)`` in degrees at which the scale is measured.

    Returns
    -------
    scale : float
        Geometric mean of the x and y angular pixel scales (the value of the
        ``SkyCoord.separation`` Angles, presumably degrees).

    Raises
    ------
    ValueError
        If ``fiducial`` does not have exactly two elements.
    """
    if len(fiducial) != 2:
        raise ValueError(f'Input fiducial must contain only (RA, DEC); Instead recieved: {fiducial}')

    ref_pixel = np.array(wcs.invert(*fiducial))
    neighbour_x = ref_pixel + (1, 0)
    neighbour_y = ref_pixel + (0, 1)
    pixel_columns = np.vstack((ref_pixel, neighbour_x, neighbour_y)).T

    world = wcs(*pixel_columns)
    points = SkyCoord(ra=world[0], dec=world[1], unit="deg")
    origin = points[0]

    scale_x = np.abs(origin.separation(points[1]).value)
    scale_y = np.abs(origin.separation(points[2]).value)
    return np.sqrt(scale_x * scale_y)
예제 #3
0
파일: util.py 프로젝트: spacetelescope/jwst
def compute_scale(wcs: WCS,
                  fiducial: Union[tuple, np.ndarray],
                  disp_axis: int = None,
                  pscale_ratio: float = None) -> float:
    """Compute the angular pixel scale of ``wcs`` around ``fiducial``.

    The fiducial is inverted to a reference pixel.  Two neighbouring pixels
    are built by stepping one pixel along the axis at the index of the first
    ``'SPATIAL'`` output axis, and along that step vector rolled by one.  The
    angular separations of these neighbours from the reference point give
    the x and y scales.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        Reference WCS object from which to compute a scaling factor.

    fiducial : tuple
        Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in
        calculating reference points.

    disp_axis : int, optional
        Dispersion axis integer, same convention as
        ``wcsinfo.dispersion_direction``.  Required when the output frame
        contains a ``'SPECTRAL'`` axis.

    pscale_ratio : float, optional
        Ratio of input to output pixel scale.  When given, both scales are
        multiplied by it.

    Returns
    -------
    scale : float
        For spectral WCS: the y scale when ``disp_axis == 1``, otherwise the
        x scale.  For purely spatial WCS: ``sqrt(xscale * yscale)``.

    Raises
    ------
    ValueError
        If the WCS is spectral and ``disp_axis`` is None.
    """
    axis_kinds = wcs.output_frame.axes_type
    has_spectral = 'SPECTRAL' in axis_kinds

    if has_spectral and disp_axis is None:
        raise ValueError('If input WCS is spectral, a disp_axis must be given')

    ref_pixel = np.array(wcs.invert(*fiducial))

    spatial_axes = np.where(np.array(axis_kinds) == 'SPATIAL')[0]
    step_a = np.zeros_like(ref_pixel)
    step_a[spatial_axes[0]] = 1
    step_b = np.roll(step_a, 1)

    pixel_columns = np.vstack((ref_pixel, ref_pixel + step_a, ref_pixel + step_b)).T
    world = wcs(*pixel_columns)

    points = SkyCoord(ra=world[spatial_axes[0]],
                      dec=world[spatial_axes[1]],
                      unit="deg")
    origin = points[0]
    scale_x = np.abs(origin.separation(points[1]).value)
    scale_y = np.abs(origin.separation(points[2]).value)

    if pscale_ratio is not None:
        scale_x = scale_x * pscale_ratio
        scale_y = scale_y * pscale_ratio

    if not has_spectral:
        return np.sqrt(scale_x * scale_y)

    # Assumes the scale does not change with wavelength and that disp_axis
    # follows DataModel.meta.wcsinfo.dispersion_direction.
    if disp_axis == 1:
        return scale_y
    return scale_x
예제 #4
0
def gwcs_1d():
    """Build a 1-D gwcs: one pixel axis mapped by ``Identity(1)`` to a spectral frame in nm.

    Returns
    -------
    WCS
        WCS with a one-axis "detector" input frame and a "spectral" output
        frame.
    """
    # Use real 1-tuples: ("pixel") is just the string "pixel", not a tuple.
    detector_frame = cf.CoordinateFrame(
        name="detector",
        naxes=1,
        axes_order=(0, ),
        axes_type=("pixel", ),
        axes_names=("x", ),
        unit=(u.pix, ))

    # NOTE(review): axes_order=(2,) matches the 3-D fixture; a standalone 1-D
    # output frame would normally use (0,).  Kept unchanged so existing tests
    # are unaffected; confirm whether this is intended.
    spec_frame = cf.SpectralFrame(name="spectral", axes_order=(2, ), unit=u.nm)

    return WCS(forward_transform=Identity(1), input_frame=detector_frame, output_frame=spec_frame)
예제 #5
0
def assign_moving_target_wcs(input_model):
    """Append a 'moving_target' output frame to the WCS of every model.

    For each exposure, the new frame is reached from the current output frame
    by shifting RA/Dec by ``mean(mt_ra) - mt_ra`` and ``mean(mt_dec) - mt_dec``.
    The means are taken over all exposures in the container.

    Parameters
    ----------
    input_model : ModelContainer
        Models to update in place.

    Returns
    -------
    input_model : ModelContainer
        The same container.  Each model gets ``meta.cal_step.assign_mtwcs``
        set to 'COMPLETE', or to 'SKIPPED' (with WCS untouched) when any
        exposure is missing ``mt_ra``/``mt_dec``.

    Raises
    ------
    ValueError
        If ``input_model`` is not a ModelContainer, or a model's output frame
        is neither a CelestialFrame nor a CompositeFrame.
    """
    if not isinstance(input_model, datamodels.ModelContainer):
        raise ValueError("Expected a ModelContainer object")

    members = input_model._models
    all_ra = np.array([member.meta.wcsinfo.mt_ra for member in members])
    all_dec = np.array([member.meta.wcsinfo.mt_dec for member in members])

    # Without every exposure's MT position the average is meaningless: flag
    # all models as skipped and leave their WCS alone.
    if None in all_ra or None in all_dec:
        log.warning("One or more MT RA/Dec values missing in input images")
        log.warning("Step will be skipped, resulting in target misalignment")
        for member in input_model:
            member.meta.cal_step.assign_mtwcs = 'SKIPPED'
        return input_model

    mean_ra = all_ra.mean()
    mean_dec = all_dec.mean()

    for member in input_model:
        old_wcs = member.meta.wcs
        steps = old_wcs._pipeline[:-1]

        target_frame = deepcopy(old_wcs.output_frame)
        target_frame.name = 'moving_target'

        info = member.meta.wcsinfo
        own_ra = info.mt_ra
        own_dec = info.mt_dec
        info.mt_avra = mean_ra
        info.mt_avdec = mean_dec

        ra_offset = mean_ra - own_ra
        dec_offset = mean_dec - own_dec

        if isinstance(target_frame, cf.CelestialFrame):
            to_target = Shift(ra_offset) & Shift(dec_offset)
        elif isinstance(target_frame, cf.CompositeFrame):
            # Pass the spectral axis through unchanged.
            to_target = Shift(ra_offset) & Shift(dec_offset) & Identity(1)
        else:
            raise ValueError("Unrecognized coordinate frame.")

        steps.append((old_wcs.output_frame, to_target))
        steps.append((target_frame, None))
        replacement = WCS(steps)
        del member.meta.wcs
        member.meta.wcs = replacement
        member.meta.cal_step.assign_mtwcs = 'COMPLETE'

    return input_model
예제 #6
0
def wcs_from_spec_footprints(wcslist,
                             refwcs=None,
                             transform=None,
                             domain=None):
    """
    Create a WCS from a list of spatial/spectral WCS.

    Build-7 workaround: although a list is accepted, only the footprint of
    the first WCS in ``wcslist`` is used to build the output domain.

    Parameters
    ----------
    wcslist : iterable of `~gwcs.wcs.WCS`
        Input WCS objects.
    refwcs : `~gwcs.wcs.WCS`, optional
        Reference WCS supplying the input and output frames.  Defaults to
        ``wcslist[0]``.
    transform : optional
        Accepted for API compatibility only.  It is ignored: the output
        transform is always recomputed from the fiducial point and ``refwcs``.
    domain : optional
        Forwarded to ``compute_spec_fiducial``.

    Returns
    -------
    wnew : `~gwcs.wcs.WCS`
        Output WCS whose forward transform is preceded by a per-axis ``Shift``
        by the minimum pixel coordinate of the first footprint, and whose
        ``domain`` spans ``[min, max + 1)`` of that footprint on each axis.

    Raises
    ------
    ValueError
        If ``wcslist`` is not iterable.
    TypeError
        If an item of ``wcslist`` or ``refwcs`` is not a `~gwcs.wcs.WCS`.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")
    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError(
            "All items in 'wcslist' must have instance of gwcs.WCS")
    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs.  For
    # now, we just do it for one, using the api for a list of wcs.
    # Compute a fiducial point for the output frame at center of input data
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    # Create transform for output frame (any caller-supplied ``transform`` is
    # intentionally replaced here).
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the domain in the output frame wcs object by running the input wcs
    # footprints through the backward transform of the output wcs
    sky = [spec_footprint(w) for w in wcslist]
    domain_grid = [wnew.backward_transform(*f) for f in sky]

    # Per-axis extent of the first footprint, computed once and reused for
    # both the pixel offsets and the domain bounds.
    det = domain_grid[0]
    input_frame = refwcs.input_frame
    axis_mins = [np.nanmin(det[axis]) for axis in input_frame.axes_order]
    axis_maxs = [np.nanmax(det[axis]) + 1 for axis in input_frame.axes_order]

    # Only the first two axes are shifted, as the transform is built for a
    # 2-D input frame.
    transform = Shift(axis_mins[0]) & Shift(axis_mins[1]) | transform
    wnew = WCS(output_frame=output_frame,
               input_frame=input_frame,
               forward_transform=transform)

    # NOTE(review): these bounds are in the unshifted pixel coordinates used
    # to compute the offsets above; confirm that is intended for the shifted
    # transform.
    wnew.domain = [
        {
            'lower': lower,
            'upper': upper,
            'includes_lower': True,
            'includes_upper': False
        }
        for lower, upper in zip(axis_mins, axis_maxs)
    ]
    return wnew
예제 #7
0
def gwcs_3d():
    """Build a 3-D gwcs fixture.

    Three pixel axes are mapped by ``spatial_like_model()`` onto a composite
    output frame.  That frame is a helioprojective celestial frame ('hpc')
    plus a spectral frame in nm on axis 2.
    """
    pixel_frame = cf.CoordinateFrame(
        name="detector",
        naxes=3,
        axes_order=(0, 1, 2),
        axes_type=("pixel",) * 3,
        axes_names=("x", "y", "z"),
        unit=(u.pix,) * 3,
    )

    celestial = cf.CelestialFrame(reference_frame=Helioprojective(), name='hpc')
    spectral = cf.SpectralFrame(name="spectral", axes_order=(2, ), unit=u.nm)
    world_frame = cf.CompositeFrame(frames=(celestial, spectral))

    return WCS(
        forward_transform=spatial_like_model(),
        input_frame=pixel_frame,
        output_frame=world_frame,
    )
예제 #8
0
def add_mt_frame(wcs, ra_average, dec_average, mt_ra, mt_dec):
    """Return a copy of ``wcs`` with a trailing "moving_target" frame.

    The last step of the original pipeline is dropped.  The original output
    frame is re-added with a ``Shift`` transform (RA by
    ``ra_average - mt_ra``, Dec by ``dec_average - mt_dec``, plus an
    ``Identity(1)`` pass-through for composite frames).  A deep copy of that
    frame renamed "moving_target" then becomes the new output frame.

    Parameters
    ----------
    wcs : `~gwcs.WCS`
        WCS object for the observation or slit.
    ra_average : float
        The average RA of all observations.
    dec_average : float
        The average DEC of all observations.
    mt_ra, mt_dec : float
        The RA, DEC of the moving target in the observation.

    Returns
    -------
    new_wcs : `~gwcs.WCS`
        The WCS for the moving target observation.

    Raises
    ------
    ValueError
        If the output frame is neither a CelestialFrame nor a CompositeFrame.
    """
    steps = wcs._pipeline[:-1]

    moving_frame = deepcopy(wcs.output_frame)
    moving_frame.name = 'moving_target'

    delta_ra = ra_average - mt_ra
    delta_dec = dec_average - mt_dec

    if isinstance(moving_frame, cf.CelestialFrame):
        offset_model = Shift(delta_ra) & Shift(delta_dec)
    elif isinstance(moving_frame, cf.CompositeFrame):
        offset_model = Shift(delta_ra) & Shift(delta_dec) & Identity(1)
    else:
        raise ValueError("Unrecognized coordinate frame.")

    steps.extend([(wcs.output_frame, offset_model), (moving_frame, None)])
    return WCS(steps)
예제 #9
0
def assign_moving_target_wcs(input_model):
    """
    Add a 'moving_target' output frame to the WCS of every model in a container.

    For each exposure, the new frame is reached from the current output frame
    by shifting RA/Dec by ``mean(mt_ra) - mt_ra`` and ``mean(mt_dec) - mt_dec``.
    The means are taken over all exposures in the container.

    Parameters
    ----------
    input_model : ModelContainer
        Models to update in place.

    Returns
    -------
    input_model : ModelContainer
        The same container.  Each model's ``meta.cal_step.assign_mtwcs`` is
        set to 'COMPLETE', or to 'SKIPPED' (WCS left unchanged) when any
        exposure is missing ``mt_ra``/``mt_dec``.

    Raises
    ------
    ValueError
        If ``input_model`` is not a ModelContainer, or a model's output frame
        is neither a CelestialFrame nor a CompositeFrame.
    """
    import logging

    if not isinstance(input_model, datamodels.ModelContainer):
        raise ValueError("Expected a ModelContainer object")

    all_mt_ra = np.array(
        [model.meta.wcsinfo.mt_ra for model in input_model._models])
    all_mt_dec = np.array(
        [model.meta.wcsinfo.mt_dec for model in input_model._models])

    # A missing value makes these object arrays contain None, and .mean()
    # would then raise TypeError.  Skip the step explicitly instead.
    if (None in all_mt_ra) or (None in all_mt_dec):
        log = logging.getLogger(__name__)
        log.warning("One or more MT RA/Dec values missing in input images")
        log.warning("Step will be skipped, resulting in target misalignment")
        for model in input_model:
            model.meta.cal_step.assign_mtwcs = 'SKIPPED'
        return input_model

    mt_avra = all_mt_ra.mean()
    mt_avdec = all_mt_dec.mean()

    for model in input_model:
        pipeline = model.meta.wcs._pipeline[:-1]

        mt = deepcopy(model.meta.wcs.output_frame)
        mt.name = 'moving_target'

        mt_ra = model.meta.wcsinfo.mt_ra
        mt_dec = model.meta.wcsinfo.mt_dec
        model.meta.wcsinfo.mt_avra = mt_avra
        model.meta.wcsinfo.mt_avdec = mt_avdec

        rdel = mt_avra - mt_ra
        ddel = mt_avdec - mt_dec

        if isinstance(mt, cf.CelestialFrame):
            transform_to_mt = Shift(rdel) & Shift(ddel)
        elif isinstance(mt, cf.CompositeFrame):
            # Pass the spectral axis through unchanged.
            transform_to_mt = Shift(rdel) & Shift(ddel) & Identity(1)
        else:
            raise ValueError("Unrecognized coordinate frame.")
        pipeline.append((model.meta.wcs.output_frame, transform_to_mt))
        pipeline.append((mt, None))
        new_wcs = WCS(pipeline)
        del model.meta.wcs
        model.meta.wcs = new_wcs
        model.meta.cal_step.assign_mtwcs = 'COMPLETE'
    return input_model
예제 #10
0
    def build_nirspec_output_wcs(self, refmodel=None):
        """
        Create a spatial/spectral WCS covering footprint of the input.

        The output pipeline is ``detector -> slit_frame -> world``.  The
        ``slit_frame -> world`` step is the reference model's own transform.
        The ``detector -> slit_frame`` step is newly built: detector y feeds a
        linear slit-position model, and detector x feeds a tabular wavelength
        model.

        Parameters
        ----------
        refmodel : DataModel, optional
            Reference model supplying the slit/detector/world transforms.  Its
            WCS is placed first in the list of input WCS objects.  Defaults to
            ``self.input_models[0]``.

        Returns
        -------
        output_wcs : `~gwcs.WCS`
            The output WCS with ``bounding_box`` and ``array_shape`` set.
            As a side effect ``self.data_size`` is set to ``(ny, n_lam)``.

        Raises
        ------
        ValueError
            If no output wavelengths can be determined, or if none of the
            sampled virtual-slit end points map to finite detector positions.
        """
        # Reference WCS first, followed by the remaining inputs in order.
        all_wcs = [m.meta.wcs for m in self.input_models if m is not refmodel]
        if refmodel:
            all_wcs.insert(0, refmodel.meta.wcs)
        else:
            refmodel = self.input_models[0]

        refwcs = refmodel.meta.wcs

        s2d = refwcs.get_transform('slit_frame', 'detector')
        d2s = refwcs.get_transform('detector', 'slit_frame')
        s2w = refwcs.get_transform('slit_frame', 'world')

        # estimate position of the target without relying in the meta.target:
        bbox = refwcs.bounding_box

        grid = wcstools.grid_from_bounding_box(bbox)
        _, s, lam = np.array(d2s(*grid))
        # Flux-weighted mean slit position and wavelength over finite pixels.
        sd = s * refmodel.data
        ld = lam * refmodel.data
        good_s = np.isfinite(sd)
        if np.any(good_s):
            total = np.sum(refmodel.data[good_s])
            wmean_s = np.sum(sd[good_s]) / total
            wmean_l = np.sum(ld[good_s]) / total
        else:
            # NOTE(review): this is half the slit length, not the slit center
            # 0.5 * (slit_ymin + slit_ymax); confirm which is intended.
            wmean_s = 0.5 * (refmodel.slit_ymax - refmodel.slit_ymin)
            wmean_l = d2s(*np.mean(bbox, axis=1))[2]

        targ_ra, targ_dec, _ = s2w(0, wmean_s, wmean_l)

        ref_lam = _find_nirspec_output_sampling_wavelengths(
            all_wcs,
            targ_ra, targ_dec
        )
        ref_lam = np.array(ref_lam)

        n_lam = ref_lam.size
        if not n_lam:
            raise ValueError("Not enough data to construct output WCS.")

        x_slit = np.zeros(n_lam)
        # ref_lam is scaled by 1e-6 before being passed to s2d.
        lam = 1e-6 * ref_lam

        # Find the spatial pixel scale:
        y_slit_min, y_slit_max = self._max_virtual_slit_extent(all_wcs, targ_ra, targ_dec)

        # Sample the virtual-slit end points at nsampl wavelengths spread over
        # the output wavelength grid.
        nsampl = 50
        xy_min = s2d(
            nsampl * [0],
            nsampl * [y_slit_min],
            lam[(tuple((i * n_lam) // nsampl for i in range(nsampl)), )]
        )
        xy_max = s2d(
            nsampl * [0],
            nsampl * [y_slit_max],
            lam[(tuple((i * n_lam) // nsampl for i in range(nsampl)), )]
        )

        good = np.logical_and(np.isfinite(xy_min), np.isfinite(xy_max))
        if not np.any(good):
            raise ValueError("Error estimating output WCS pixel scale.")

        # Slit-coordinate length per detector pixel, from the longest
        # projected slit over all wavelengths.
        xy1 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymin), lam)
        xy2 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymax), lam)
        xylen = np.nanmax(np.linalg.norm(np.array(xy1) - np.array(xy2), axis=0)) + 1
        pscale = (refmodel.slit_ymax - refmodel.slit_ymin) / xylen

        # compute image span along Y-axis (length of the slit in the detector plane)
        det_slit_span = np.nanmax(np.linalg.norm(np.subtract(xy_max, xy_min), axis=0))
        ny = int(np.ceil(det_slit_span * self.pscale_ratio + 0.5)) + 1

        # Padding (in output pixels) on each side of the slit span.
        border = 0.5 * (ny - det_slit_span * self.pscale_ratio) - 0.5

        # Orient the slit model so output y increases in the same direction as
        # detector y along the slit.
        if xy_min[1][1] < xy_max[1][1]:
            y_slit_model = Linear1D(
                slope=pscale / self.pscale_ratio,
                intercept=y_slit_min - border * pscale * self.pscale_ratio
            )
        else:
            y_slit_model = Linear1D(
                slope=-pscale / self.pscale_ratio,
                intercept=y_slit_max + border * pscale * self.pscale_ratio
            )

        # extrapolate 1/2 pixel at the edges and make tabular model w/inverse:
        lam = lam.tolist()
        pixel_coord = list(range(n_lam))

        if len(pixel_coord) > 1:
            # left:
            slope = (lam[1] - lam[0]) / pixel_coord[1]
            lam.insert(0, -0.5 * slope + lam[0])
            pixel_coord.insert(0, -0.5)
            # right: extend the last interval by half a pixel.  The point must
            # be relative to the last sample; the old form
            # ``slope * (pixel_coord[-1] + 0.5) + lam[-2]`` was only correct
            # for n_lam == 2.
            slope = (lam[-1] - lam[-2]) / (pixel_coord[-1] - pixel_coord[-2])
            lam.append(lam[-1] + 0.5 * slope)
            pixel_coord.append(pixel_coord[-1] + 0.5)

        else:
            # NOTE(review): with a single wavelength the inverse table below
            # gets three identical points, which interpolation may reject.
            lam = 3 * lam
            pixel_coord = [-0.5, 0, 0.5]

        wavelength_transform = Tabular1D(points=pixel_coord,
                                         lookup_table=lam,
                                         bounds_error=False, fill_value=np.nan)
        wavelength_transform.inverse = Tabular1D(points=lam,
                                                 lookup_table=pixel_coord,
                                                 bounds_error=False,
                                                 fill_value=np.nan)
        self.data_size = (ny, len(ref_lam))

        # Construct the final transform
        # NOTE(review): the x_slit output is Identity on detector x, while
        # s2w was evaluated at x_slit=0 above; confirm this is intended.
        mapping = Mapping((0, 1, 0))
        mapping.inverse = Mapping((2, 1))
        out_det2slit = mapping | Identity(1) & y_slit_model & wavelength_transform

        # Create coordinate frames
        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        slit_spatial = cf.Frame2D(name='slit_spatial', axes_order=(0, 1),
                                  unit=("", ""), axes_names=('x_slit', 'y_slit'))
        spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                                unit=(u.micron,), axes_names=('wavelength',))
        slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame')
        sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                                reference_frame=coord.ICRS())
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, out_det2slit), (slit_frame, s2w), (world, None)]
        output_wcs = WCS(pipeline)

        # Bounding box and array shape follow self.data_size = (ny, n_lam).
        # The half pixels at the slit ends are presumably what the "+ 1" in
        # ny above accounts for.
        bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
        output_wcs.bounding_box = bounding_box
        output_wcs.array_shape = self.data_size

        return output_wcs
예제 #11
0
    def build_nirspec_lamp_output_wcs(self):
        """
        Create a spatial/spectral WCS output frame for NIRSpec lamp mode

        Creates output frame by linearly fitting x_msa, y_msa along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.  Only ``self.input_models[0]`` is used.  Output pixels are
        spaced by ``1 / self.pscale_ratio`` in the fitted pixel coordinate.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            A gwcs WCS object defining the output frame WCS, with
            ``array_shape``, ``pixel_shape`` and ``bounding_box`` set.
        """
        model = self.input_models[0]
        wcs = model.meta.wcs
        bbox = wcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bbox)
        x_msa, y_msa, lam = np.array(wcs(*grid))
        # Handle vertical (MIRI) or horizontal (NIRSpec) dispersion.  The
        # following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
        spectral_axis = find_dispersion_axis(model)
        spatial_axis = spectral_axis ^ 1

        # Compute the wavelength array, trimming NaNs from the ends.
        # Whole slices are often all-NaN, so silence those warnings for this
        # call only.  The previous warnings.resetwarnings() cleared every
        # warning filter installed elsewhere in the process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            wavelength_array = np.nanmedian(lam, axis=spectral_axis)
        wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

        # Sample x_msa / y_msa along the slit at the central wavelength,
        # slicing along the spatial direction for either dispersion orientation.
        lam_center_index = int((bbox[spectral_axis][1] -
                                bbox[spectral_axis][0]) / 2)
        if spectral_axis:
            x_msa_array = x_msa[lam_center_index]
            y_msa_array = y_msa[lam_center_index]
        else:
            x_msa_array = x_msa.T[lam_center_index]
            y_msa_array = y_msa.T[lam_center_index]
        x_msa_array = x_msa_array[~np.isnan(x_msa_array)]
        y_msa_array = y_msa_array[~np.isnan(y_msa_array)]

        # Output pixel coordinates are ``arange(n) * step``.  This gives the
        # same values as ``np.arange(0, n / r, 1 / r)`` but always has exactly
        # n elements; the float-step form can gain an element through rounding
        # and break the fits and lookup tables below.
        step = 1 / self.pscale_ratio

        # Estimate and fit the spatial sampling
        fitter = LinearLSQFitter()
        fit_model = Linear1D()
        pix_to_x_msa = fitter(fit_model, np.arange(x_msa_array.shape[0]) * step, x_msa_array)
        pix_to_y_msa = fitter(fit_model, np.arange(y_msa_array.shape[0]) * step, y_msa_array)

        wavelength_pixels = np.arange(wavelength_array.shape[0]) * step
        pix_to_wavelength = Tabular1D(points=wavelength_pixels,
                                      lookup_table=wavelength_array,
                                      bounds_error=False, fill_value=None,
                                      name='pix2wavelength')

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = wavelength_pixels

        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=points,
                                              lookup_table=lookup_table,
                                              bounds_error=False, fill_value=None,
                                              name='wavelength2pix')

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))
        mapping.inverse = Mapping((2, 1))

        # The final transform
        # define the output wcs
        transform = mapping | pix_to_x_msa & pix_to_y_msa & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        sky = cf.Frame2D(name=f'resampled_{model.meta.wcs.output_frame.name}', axes_order=(0, 1))
        spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                                unit=(u.micron,), axes_names=('wavelength',))
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, transform),
                    (world, None)]

        output_wcs = WCS(pipeline)

        # Compute the output array size (in x, y order) and bounding box
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array) / self.pscale_ratio))
        x_size = len(x_msa_array)
        output_array_size[spatial_axis] = int(np.ceil(x_size / self.pscale_ratio))
        # turn the size into a numpy shape in (y, x) order
        output_wcs.array_shape = output_array_size[::-1]
        output_wcs.pixel_shape = output_array_size
        bounding_box = resample_utils.wcs_bbox_from_shape(output_array_size[::-1])
        output_wcs.bounding_box = bounding_box

        return output_wcs
예제 #12
0
    def build_interpolated_output_wcs(self, refmodel=None):
        """
        Build a spatial/spectral output WCS from a single reference model.

        RA and Dec are each fit with a straight line against pixel index along
        the slit, using the cut at the centre of the dispersion direction.
        Wavelength is a lookup table built from the per-column (or per-row)
        NaN-median of the reference wavelengths.

        Parameters
        ----------
        refmodel : `~jwst.datamodels.DataModel`, optional
            Model whose WCS is sampled.  Defaults to ``self.input_models[0]``.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            ``detector -> world`` WCS with ``bounding_box`` set.  As a side
            effect ``self.data_size`` is set to the (y, x) output shape.
        """
        model = self.input_models[0] if refmodel is None else refmodel
        ref_wcs = model.meta.wcs
        bbox = ref_wcs.bounding_box

        ra, dec, lam = np.array(ref_wcs(*wcstools.grid_from_bounding_box(bbox)))

        disp = find_dispersion_axis(lam)
        cross = disp ^ 1

        # Per-slice median wavelength with NaN entries dropped.
        waves = np.nanmedian(lam, axis=disp)
        waves = waves[~np.isnan(waves)]

        # RA/Dec along the slit at the middle of the dispersion direction.
        center = int((bbox[disp][1] - bbox[disp][0]) / 2)
        if disp:
            ra_cut, dec_cut = ra[center], dec[center]
        else:
            ra_cut, dec_cut = ra.T[center], dec.T[center]
        ra_cut = ra_cut[~np.isnan(ra_cut)]
        dec_cut = dec_cut[~np.isnan(dec_cut)]

        line_fitter = LinearLSQFitter()
        line = Linear1D()
        pix_to_ra = line_fitter(line, np.arange(ra_cut.shape[0]), ra_cut)
        pix_to_dec = line_fitter(line, np.arange(dec_cut.shape[0]), dec_cut)

        pix_to_wavelength = Tabular1D(lookup_table=waves,
                                      bounds_error=False,
                                      fill_value=None,
                                      name='pix2wavelength')

        # The inverse lookup needs ascending sample points, so flip both
        # arrays when the wavelengths are not strictly increasing.
        inv_points = waves
        inv_table = np.arange(waves.shape[0])
        if not np.all(np.diff(waves) > 0):
            inv_points = inv_points[::-1]
            inv_table = inv_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=inv_points,
                                              lookup_table=inv_table,
                                              bounds_error=False,
                                              fill_value=None,
                                              name='wavelength2pix')

        # The spatial input feeds both the RA and the Dec fit.
        mapping = Mapping((cross, cross, disp))

        # A near-zero Dec slope cannot be inverted reliably, so override the
        # default inverse mapping, ordered by detector orientation.
        if np.isclose(pix_to_dec.slope, 0, atol=1e-8):
            mapping.inverse = Mapping((1, 0) if cross else (0, 1))

        transform = mapping | pix_to_ra & pix_to_dec & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                                reference_frame=coord.ICRS())
        spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                                unit=(u.micron,), axes_names=('wavelength',))
        world = cf.CompositeFrame([sky, spec], name='world')

        output_wcs = WCS([(det, transform), (world, None)])

        # Size in (x, y) WCS order, stored reversed as a numpy (y, x) shape.
        shape_xy = [0, 0]
        shape_xy[disp] = len(waves)
        shape_xy[cross] = len(ra_cut)
        self.data_size = tuple(shape_xy[::-1])

        output_wcs.bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
        return output_wcs
예제 #13
0
파일: resample_spec.py 프로젝트: zonca/jwst
    def build_interpolated_output_wcs(self, refmodel=None):
        """
        Build a spatial/spectral output WCS from a single reference model.

        RA and Dec are each fit with a straight line against pixel index along
        the slit, using the cut at the centre of the dispersion direction.
        Wavelength is a lookup table built from the per-column (or per-row)
        NaN-median of the reference wavelengths.

        Parameters
        ----------
        refmodel : `~jwst.datamodels.DataModel`, optional
            Model whose WCS is sampled.  Defaults to ``self.input_models[0]``.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            ``detector -> world`` WCS with ``bounding_box`` set.  As a side
            effect ``self.data_size`` is set to the (y, x) output shape.
        """
        model = self.input_models[0] if refmodel is None else refmodel
        ref_wcs = model.meta.wcs
        bbox = ref_wcs.bounding_box

        ra, dec, lam = np.array(ref_wcs(*wcstools.grid_from_bounding_box(bbox)))

        disp = find_dispersion_axis(lam)
        cross = disp ^ 1

        # Per-slice median wavelength with NaN entries dropped.
        waves = np.nanmedian(lam, axis=disp)
        waves = waves[~np.isnan(waves)]

        # RA/Dec along the slit at the middle of the dispersion direction.
        center = int((bbox[disp][1] - bbox[disp][0]) / 2)
        if disp:
            ra_cut, dec_cut = ra[center], dec[center]
        else:
            ra_cut, dec_cut = ra.T[center], dec.T[center]
        ra_cut = ra_cut[~np.isnan(ra_cut)]
        dec_cut = dec_cut[~np.isnan(dec_cut)]

        line_fitter = LinearLSQFitter()
        line = Linear1D()
        pix_to_ra = line_fitter(line, np.arange(ra_cut.shape[0]), ra_cut)
        pix_to_dec = line_fitter(line, np.arange(dec_cut.shape[0]), dec_cut)

        pix_to_wavelength = Tabular1D(lookup_table=waves,
                                      bounds_error=False,
                                      fill_value=None,
                                      name='pix2wavelength')

        # The inverse lookup needs ascending sample points, so flip both
        # arrays when the wavelengths are not strictly increasing.
        inv_points = waves
        inv_table = np.arange(waves.shape[0])
        if not np.all(np.diff(waves) > 0):
            inv_points = inv_points[::-1]
            inv_table = inv_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=inv_points,
                                              lookup_table=inv_table,
                                              bounds_error=False,
                                              fill_value=None,
                                              name='wavelength2pix')

        # The spatial input feeds both the RA and the Dec fit.
        mapping = Mapping((cross, cross, disp))

        # A near-zero Dec slope cannot be inverted reliably, so override the
        # default inverse mapping, ordered by detector orientation.
        if np.isclose(pix_to_dec.slope, 0, atol=1e-8):
            mapping.inverse = Mapping((1, 0) if cross else (0, 1))

        transform = mapping | pix_to_ra & pix_to_dec & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        sky = cf.CelestialFrame(name='sky',
                                axes_order=(0, 1),
                                reference_frame=coord.ICRS())
        spec = cf.SpectralFrame(name='spectral',
                                axes_order=(2, ),
                                unit=(u.micron, ),
                                axes_names=('wavelength', ))
        world = cf.CompositeFrame([sky, spec], name='world')

        output_wcs = WCS([(det, transform), (world, None)])

        # Size in (x, y) WCS order, stored reversed as a numpy (y, x) shape.
        shape_xy = [0, 0]
        shape_xy[disp] = len(waves)
        shape_xy[cross] = len(ra_cut)
        self.data_size = tuple(shape_xy[::-1])

        output_wcs.bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
        return output_wcs
예제 #14
0
    def build_nirspec_output_wcs(self, refmodel=None):
        """
        Build a rectified slit-frame WCS spanning the footprint of the input.

        The new detector -> slit transform maps output rows linearly onto the
        slit's ``y_slit`` coordinate and output columns onto wavelength through
        a lookup table.  The slit -> world transform of the reference WCS is
        reused unchanged.

        Parameters
        ----------
        refmodel : DataModel, optional
            Model whose ``meta.wcs`` is used as the reference.  When not given
            (or falsy) the first entry of ``self.input_models`` is used.

        Returns
        -------
        output_wcs : gwcs.WCS
            WCS with a ``detector -> slit_frame -> world`` pipeline.  As a side
            effect ``self.data_size`` is set to ``(n_rows, n_wavelengths)``.
        """
        model = refmodel if refmodel else self.input_models[0]
        refwcs = model.meta.wcs

        det2slit = refwcs.get_transform('detector', 'slit_frame')
        slit2world = refwcs.get_transform('slit_frame', 'world')

        # Evaluate detector -> slit on every pixel of the bounding box
        pixel_grid = wcstools.grid_from_bounding_box(refwcs.bounding_box,
                                                     step=(1, 1))
        SlitCoords = namedtuple('Grid', refwcs.slit_frame.axes_names)
        slit_coords = SlitCoords(*det2slit(*pixel_grid))
        wavelengths = slit_coords.wavelength

        # Pixel length of one shutter: send the shutter edges
        # (x_slit=0, y_slit=-0.5 and +0.5) back to the detector at a
        # representative wavelength and measure their separation.
        ref_wavelength = np.nanmean(wavelengths)
        shutter_edges = np.array([[0., 0.], [-.5, .5],
                                  np.repeat(ref_wavelength, 2)])
        edge_pixels = np.array(det2slit.inverse(*shutter_edges)).T
        pix_per_shutter = np.linalg.norm(edge_pixels[0] - edge_pixels[1])

        # Slit extent along y, expressed in pixels
        ymin = np.nanmin(slit_coords.y_slit) * pix_per_shutter
        ymax = np.nanmax(slit_coords.y_slit) * pix_per_shutter
        slit_height_pix = int(abs(ymax - ymin + 0.5))

        # One wavelength per output column (mean over rows) and its inverse
        column_wavelengths = np.nanmean(wavelengths, axis=0)
        wavelength_transform = Tabular1D(lookup_table=column_wavelengths,
                                         bounds_error=False,
                                         fill_value=np.nan)
        wavelength_transform.inverse = Tabular1D(
            points=column_wavelengths,
            lookup_table=np.arange(wavelengths.shape[1]),
            bounds_error=False,
            fill_value=np.nan)

        # Output detector -> slit: x_slit is constant, y_slit is linear in
        # the output row, wavelength comes from the output column.
        x_slit = Const1D(-0.)
        x_slit.inverse = Const1D(0.)
        y_slit = Scale(-1 / pix_per_shutter) | Shift(ymax / pix_per_shutter)
        axes_mapping = Mapping((0, 1, 0))
        axes_mapping.inverse = Mapping((2, 1))
        out_det2slit = axes_mapping | (x_slit & y_slit & wavelength_transform)

        # Coordinate frames
        spec = cf.SpectralFrame(name='spectral',
                                axes_order=(2, ),
                                unit=(u.micron, ),
                                axes_names=('wavelength', ))
        slit_spatial = cf.Frame2D(name='slit_spatial',
                                  axes_order=(0, 1),
                                  unit=("", ""),
                                  axes_names=('x_slit', 'y_slit'))
        sky = cf.CelestialFrame(name='sky',
                                axes_order=(0, 1),
                                reference_frame=coord.ICRS())
        detector = cf.Frame2D(name='detector', axes_order=(0, 1))
        slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame')
        world = cf.CompositeFrame([sky, spec], name='world')

        output_wcs = WCS([(detector, out_det2slit),
                          (slit_frame, slit2world),
                          (world, None)])

        self.data_size = (slit_height_pix, len(column_wavelengths))
        output_wcs.bounding_box = resample_utils.bounding_box_from_shape(
            self.data_size)
        return output_wcs
예제 #15
0
파일: util.py 프로젝트: mperrin/jwst
def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=None):
    """
    Create a WCS from a list of input data models.

    A fiducial point in the output coordinate frame is created from the
    footprints of all WCS objects. For a spatial frame this is the center
    of the union of the footprints. For a spectral frame the fiducial is at
    the beginning of the footprint range.
    If ``refmodel`` is None, the first WCS object in the list is considered
    a reference. The output coordinate frame and projection (for celestial frames)
    is taken from ``refmodel``.
    If ``transform`` is not supplied, a compound transform is created using
    CDELTs and PC.
    If ``bounding_box`` is not supplied, the bounding_box of the new WCS is computed
    from bounding_box of all input WCSs.

    Parameters
    ----------
    dmodels : list of `~jwst.datamodels.DataModel`
        A list of data models.
    refmodel : `~jwst.datamodels.DataModel`, optional
        This model's WCS is used as a reference.
        The output coordinate frame, the projection and a
        scaling and rotation transform is created from it. If not supplied
        the first model in the list is used as ``refmodel``.
    transform : `~astropy.modeling.core.Model`, optional
        A transform, passed to :meth:`~gwcs.wcstools.wcs_from_fiducial`
        If not supplied Scaling | Rotation is computed from ``refmodel``.
    bounding_box : tuple, optional
        Bounding_box of the new WCS.
        If not supplied it is computed from the bounding_box of all inputs.

    Returns
    -------
    wnew : `~gwcs.WCS`
        The output WCS.  Its input frame is taken from the first model, its
        output frame from ``refmodel``.

    Raises
    ------
    TypeError
        If an item's ``meta.wcs`` is not a ``gwcs.WCS`` or ``refmodel`` is not
        a ``DataModel``.
    """
    bb = bounding_box
    wcslist = [im.meta.wcs for im in dmodels]

    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of WCS objects.")

    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError("All items in wcslist are to be instances of gwcs.WCS.")

    if refmodel is None:
        refmodel = dmodels[0]
    elif not isinstance(refmodel, DataModel):
        raise TypeError("Expected refmodel to be an instance of DataModel.")

    fiducial = compute_fiducial(wcslist, bb)
    ref_fiducial = compute_fiducial([refmodel.meta.wcs])

    prj = astmodels.Pix2Sky_TAN()

    # The sky axes are needed below to center the output grid whether or not
    # the caller supplied ``transform``, so determine them unconditionally.
    # (Previously they were only bound in the default-transform branch, which
    # made a user-supplied ``transform`` fail with NameError.)
    wcsinfo = pointing.wcsinfo_from_model(refmodel)
    sky_axes, _spec, _other = gwutils.get_axes(wcsinfo)

    if transform is None:
        transform = []

        # calc_rotation_matrix returns a flat list of four floats; reshape it
        # into the 2x2 PC matrix expected by AffineTransformation2D.
        pc = np.reshape(
            calc_rotation_matrix(
                np.deg2rad(refmodel.meta.wcsinfo.roll_ref),
                np.deg2rad(refmodel.meta.wcsinfo.v3yangle),
                vparity=refmodel.meta.wcsinfo.vparity),
            (2, 2)
        )

        rotation = astmodels.AffineTransformation2D(pc)
        transform.append(rotation)

        if sky_axes:
            scale = compute_scale(refmodel.meta.wcs, ref_fiducial)
            transform.append(astmodels.Scale(scale) & astmodels.Scale(scale))

        if transform:
            transform = functools.reduce(lambda x, y: x | y, transform)

    out_frame = refmodel.meta.wcs.output_frame
    input_frame = dmodels[0].meta.wcs.input_frame
    wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame, projection=prj, transform=transform)
    # Temporary fix before gwcs 0.14 is released (wcs_from_fiducial does not
    # yet accept ``input_frame``): swap the input frame into the pipeline.
    pipe = wnew.pipeline[:]
    pipe[0] = (input_frame, pipe[0][1])
    wnew = WCS(pipe)

    # Map every input footprint into the new pixel grid and shift each axis
    # so that its minimum lands at -0.5.
    footprints = [w.footprint().T for w in wcslist]
    domain_bounds = np.hstack([wnew.backward_transform(*f) for f in footprints])

    for axs in domain_bounds:
        axs -= (axs.min() + .5)

    bounding_box = tuple(
        (domain_bounds[axis].min(), domain_bounds[axis].max())
        for axis in out_frame.axes_order
    )

    # Center the output grid on the sky axes
    ax1, ax2 = np.array(bounding_box)[sky_axes]
    offset1 = (ax1[1] - ax1[0]) / 2
    offset2 = (ax2[1] - ax2[0]) / 2
    offsets = astmodels.Shift(-offset1) & astmodels.Shift(-offset2)

    wnew.insert_transform('detector', offsets, after=True)
    wnew.bounding_box = bounding_box

    return wnew
예제 #16
0
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used.  The slit center, angular length
        and orientation are measured from ``refwcs`` along the central column
        of its bounding box.  A linear output WCS is then built from them:
        ``shift | rotation | scale | TAN | sky rotation`` for the spatial axes
        and ``scale | shift`` for wavelength.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS.  Defaults to ``self.input_models[0].meta.wcs``.

        Notes
        -----
        Nothing is returned.  The method sets ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within bounding box
        bb = refwcs.bounding_box
        det = x, y = wcstools.grid_from_bounding_box(bb,
                                                     step=(1, 1),
                                                     center=True)
        ra, dec, lam = refwcs(*det)
        # Center column/row as indices into the grid arrays, which start at
        # the lower edge of the bounding box.
        x_center = int((bb[0][1] - bb[0][0]) / 2)
        y_center = int((bb[1][1] - bb[1][0]) / 2)
        log.debug("Center of bounding box: {}  {}".format(x_center, y_center))

        # Compute slit angular size, slit center sky coords.
        # For each detector row, interpolate the x position at which the
        # wavelength equals the wavelength at the bounding-box center.
        xpos = []
        sz = 3
        for row in lam:
            if np.isnan(row[x_center]):
                xpos.append(np.nan)
            else:
                f = interpolate.interp1d(row[x_center - sz + 1:x_center + sz],
                                         x[y_center,
                                           x_center - sz + 1:x_center + sz],
                                         bounds_error=False,
                                         fill_value='extrapolate')
                xpos.append(f(lam[y_center, x_center]))
        xpos = np.array(xpos)

        valid_rows = ~np.isnan(lam[:, x_center])
        x_arg = xpos[valid_rows]
        y_arg = y[valid_rows, x_center]
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        pix_num = np.flipud(np.arange(len(slit_ra)))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        # NOTE(review): the midpoint index of N samples is (N - 1) / 2; this
        # uses N / 2 - 1 instead -- confirm the half-pixel offset is intended.
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                                   dec=interpol_dec(slit_center_pix),
                                   unit=u.deg)
        log.debug('Slit center: {0}'.format(slit_center_sky))
        log.debug('Fiducial: {0}'.format(
            resample_utils.compute_spec_fiducial([refwcs])))
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales
        xposn = xpos[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        slit_rot_angle = slit_rot_angle.value
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)

        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale.value) & Scale(spatial_scale.value)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra.value,
                                        slit_center_sky.dec.value,
                                        lon_pole.value)

        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # Output pixel y feeds both spatial inputs; pixel x drives wavelength
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object by sending the
        # input sky/wavelength grid through the new backward transform.
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)
        bounding_box = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(bounding_box_grid[axis])
            axis_max = np.nanmax(bounding_box_grid[axis])
            bounding_box.append((axis_min, axis_max))
        wnew.bounding_box = tuple(bounding_box)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
예제 #17
0
def test_round_trip_gwcs(tmpdir):
    """
    Attach a two-step gWCS to an NDAstroData, write it to disk, read it back
    and check that transforms, frames and evaluated values all survive.
    """
    from gwcs import WCS
    from gwcs import coordinate_frames as cf

    sci_data = np.zeros((10, 10), dtype=np.float32)
    before = astrodata.create(fits.PrimaryHDU(),
                              [fits.ImageHDU(sci_data, name='SCI')])

    # --- Step 1: detector pixels -> pixels in a reference row, removing
    # relative distortions in wavelength.
    det_frame = cf.Frame2D(name='det_mosaic',
                           axes_names=('x', 'y'),
                           unit=(u.pix, u.pix))
    dref_frame = cf.Frame2D(name='dist_ref_row',
                            axes_names=('xref', 'y'),
                            unit=(u.pix, u.pix))

    # Made-up distortion that looks vaguely like real data, together with an
    # approximate (not exact) inverse that is good enough for this test.
    forward_coeffs = dict(c0_0=4.81125, c1_0=5.43375,
                          c0_1=-0.135, c1_1=-0.405,
                          c0_2=0.30375, c1_2=0.91125)
    inverse_coeffs = dict(c0_0=4.89062675, c1_0=5.68581232, c2_0=-0.00590263,
                          c0_1=0.11755526, c1_1=0.35652358, c2_1=-0.01193828,
                          c0_2=-0.29996306, c1_2=-0.91823397, c2_2=0.02390594)
    fdist = models.Chebyshev2D(2, 2, x_domain=[0., 9.], y_domain=[0., 9.],
                               **forward_coeffs)
    idist = models.Chebyshev2D(2, 2, x_domain=[-1.5, 12.], y_domain=[0., 9.],
                               **inverse_coeffs)

    distrans = models.Mapping((0, 1, 1)) | (fdist & models.Identity(1))
    distrans.inverse = models.Mapping((0, 1, 1)) | (idist & models.Identity(1))

    # --- Step 2: reference-row pixels -> linear, row-stacked spectra.
    spec_frame = cf.SpectralFrame(axes_order=(0, ),
                                  unit=u.nm,
                                  axes_names='lambda',
                                  name='wavelength')
    row_frame = cf.CoordinateFrame(1,
                                   'SPATIAL',
                                   axes_order=(1, ),
                                   unit=u.pix,
                                   axes_names='y',
                                   name='row')
    rss_frame = cf.CompositeFrame([spec_frame, row_frame])

    # Toy wavelength solution and its approximate inverse
    fwcal = models.Chebyshev1D(2, c0=500.075, c1=0.05, c2=0.001, domain=[0, 9])
    iwcal = models.Chebyshev1D(2, c0=4.59006292, c1=4.49601817,
                               c2=-0.08989608, domain=[500.026, 500.126])
    wavtrans = fwcal & models.Identity(1)
    wavtrans.inverse = iwcal & models.Identity(1)

    before[0].nddata.wcs = WCS([(det_frame, distrans),
                                (dref_frame, wavtrans),
                                (rss_frame, None)])

    # Save the AstroData instance and read it back
    path = str(tmpdir.join('round_trip_gwcs.fits'))
    before.write(path)
    after = astrodata.open(path)

    wcs_before = before[0].nddata.wcs
    wcs_after = after[0].nddata.wcs

    # # Temporary workaround for issue #9809, to ensure the test is correct:
    # wcs2.forward_transform[1].x_domain = (0, 9)
    # wcs2.forward_transform[1].y_domain = (0, 9)
    # wcs2.forward_transform[3].domain = (0, 9)
    # wcs2.backward_transform[0].domain = (500.026, 500.126)
    # wcs2.backward_transform[3].x_domain = (-1.5, 12.)
    # wcs2.backward_transform[3].y_domain = (0, 9)

    # The reloaded object must be a real gWCS instance
    assert isinstance(wcs_after, WCS)

    # Same submodels, types, degrees, domains and parameters.  The inverse is
    # checked twice (as backward_transform and via forward_transform.inverse
    # inside compare_models), since compound transforms are regenerated on
    # every access and checking it only once would be convoluted.
    for transform_attr in ('forward_transform', 'backward_transform'):
        compare_models(getattr(wcs_before, transform_attr),
                       getattr(wcs_after, transform_attr))

    # Matching coordinate frames
    for frame_name in wcs_before.available_frames:
        assert (repr(getattr(wcs_before, frame_name))
                == repr(getattr(wcs_after, frame_name)))

    # "Proof of the pudding": evaluated values agree in both directions
    y, x = np.mgrid[0:9:2, 0:9:2]
    np.testing.assert_allclose(wcs_before(x, y), wcs_after(x, y),
                               rtol=1e-7, atol=0.)

    y, w = np.mgrid[0:9:2, 500.025:500.12:0.0225]
    np.testing.assert_allclose(wcs_before.invert(w, y),
                               wcs_after.invert(w, y),
                               rtol=1e-7,
                               atol=0.)
예제 #18
0
    def build_interpolated_output_wcs(self, refmodel=None):
        """
        Create a spatial/spectral WCS output frame using all the input models

        Creates output frame by linearly fitting RA, Dec along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.

        Parameters
        ----------
        refmodel : `~jwst.datamodels.DataModel`
            The reference input image from which the fiducial WCS is created.
            If not specified, the first image in self.input_models is used.
            NOTE(review): this argument is not read by the current
            implementation; the first input model always sets the spatial
            sampling.

        Returns
        -------
        output_wcs : `~gwcs.WCS` object
            A gwcs WCS object defining the output frame WCS

        Notes
        -----
        Sets ``self.data_size`` to the output array shape in (y, x) order.
        """

        # for each input model convert slit x,y to ra,dec,lam
        # use first input model to set spatial scale
        # use center of appended ra and dec arrays to set up
        # center of final ra,dec
        # append all ra,dec, wavelength array for each slit
        # use first model to initialize wavelength array
        # append wavelengths that fall outside the endpoint of
        # of wavelength array when looping over additional data

        all_wavelength = []
        all_ra_slit = []
        all_dec_slit = []

        for im, model in enumerate(self.input_models):
            wcs = model.meta.wcs
            bb = wcs.bounding_box
            grid = wcstools.grid_from_bounding_box(bb)
            ra, dec, lam = np.array(wcs(*grid))
            spectral_axis = find_dispersion_axis(model)
            spatial_axis = spectral_axis ^ 1

            # Compute the wavelength array, trimming NaNs from the ends.
            # In many cases a whole slice is NaNs, so silence those warnings.
            # catch_warnings restores the caller's filters afterwards, unlike
            # resetwarnings(), which would discard every active filter.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                wavelength_array = np.nanmedian(lam, axis=spectral_axis)
            wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

            # We need to estimate the spatial sampling to use for the output WCS.
            # It is assumed the spatial sampling is the same for all the input
            # models. So we can use the first input model to set the spatial
            # sampling.

            # Steps to do this for first input model:
            # 1. find the middle of the spectrum in wavelength
            # 2. Pull out the ra and dec at the center of the slit.
            # 3. Find the mean ra,dec and the center of the slit this will
            #    represent the tangent point
            # 4. Convert ra,dec -> tangent plane projection: x_tan,y_tan
            # 5. using x_tan, y_tan perform a linear fit to find spatial sampling
            # first input model initializes the wavelength array and defines
            # the spatial scale of the output wcs
            if im == 0:
                all_wavelength.extend(wavelength_array)

                lam_center_index = int(
                    (bb[spectral_axis][1] - bb[spectral_axis][0]) / 2)
                if spatial_axis == 0:
                    ra_center = ra[lam_center_index, :]
                    dec_center = dec[lam_center_index, :]
                else:
                    ra_center = ra[:, lam_center_index]
                    dec_center = dec[:, lam_center_index]
                # find the ra and dec for this slit using center of slit
                ra_center_pt = np.nanmean(ra_center)
                dec_center_pt = np.nanmean(dec_center)

                if resample_utils.is_sky_like(model.meta.wcs.output_frame):
                    # convert ra and dec to tangent projection
                    tan = Pix2Sky_TAN()
                    native2celestial = RotateNative2Celestial(
                        ra_center_pt, dec_center_pt, 180)
                    undist2sky1 = tan | native2celestial
                    # Filter out RuntimeWarnings due to computed NaNs in the WCS;
                    # at this center of slit find x,y tangent projection
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        x_tan, y_tan = undist2sky1.inverse(ra, dec)
                else:
                    # for non sky-like output frames, no need to do tangent plane projections
                    # but we still use the same variables
                    x_tan, y_tan = ra, dec

                # pull out data from center
                if spectral_axis == 0:
                    x_tan_array = x_tan.T[lam_center_index]
                    y_tan_array = y_tan.T[lam_center_index]
                else:
                    x_tan_array = x_tan[lam_center_index]
                    y_tan_array = y_tan[lam_center_index]

                x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
                y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

                # Estimate the spatial sampling.  The output pixel coordinates
                # are built from integer indices so their length always equals
                # the number of samples; np.arange(0, n / r, 1 / r) can yield
                # n + 1 values through float rounding (e.g. n=10, r=3).
                fitter = LinearLSQFitter()
                fit_model = Linear1D()
                x_pix = np.arange(x_tan_array.shape[0]) / self.pscale_ratio
                y_pix = np.arange(y_tan_array.shape[0]) / self.pscale_ratio
                pix_to_xtan = fitter(fit_model, x_pix, x_tan_array)
                pix_to_ytan = fitter(fit_model, y_pix, y_tan_array)

            # append all ra and dec values to use later to find min and max
            # ra and dec
            ra_use = ra.flatten()
            ra_use = ra_use[~np.isnan(ra_use)]
            dec_use = dec.flatten()
            dec_use = dec_use[~np.isnan(dec_use)]
            all_ra_slit.append(ra_use)
            all_dec_slit.append(dec_use)

            # now check wavelength array to see if we need to add to it
            this_minw = np.min(wavelength_array)
            this_maxw = np.max(wavelength_array)
            all_minw = np.min(all_wavelength)
            all_maxw = np.max(all_wavelength)

            if this_minw < all_minw:
                all_wavelength.extend(
                    wavelength_array[wavelength_array < all_minw])
            if this_maxw > all_maxw:
                all_wavelength.extend(
                    wavelength_array[wavelength_array > all_maxw])

        # done looping over set of models

        all_ra = np.hstack(all_ra_slit)
        all_dec = np.hstack(all_dec_slit)
        all_wave = np.hstack(all_wavelength)
        all_wave = all_wave[~np.isnan(all_wave)]
        all_wave = np.sort(all_wave, axis=None)
        # Tabular interpolation model, pixels -> lambda
        wavelength_array = np.unique(all_wave)
        # Check if the data is MIRI LRS FIXED Slit. If it is then
        # the wavelength array needs to be flipped so that the resampled
        # dispersion direction matches the dispersion direction on the detector.
        if self.input_models[0].meta.exposure.type == 'MIR_LRS-FIXEDSLIT':
            wavelength_array = np.flip(wavelength_array, axis=None)

        # Output pixel coordinate of every wavelength sample.  Built from
        # integer indices so that it always has exactly one entry per
        # wavelength (see the note on the spatial fit above).
        pixel_coords = np.arange(wavelength_array.shape[0]) / self.pscale_ratio
        pix_to_wavelength = Tabular1D(points=pixel_coords,
                                      lookup_table=wavelength_array,
                                      bounds_error=False,
                                      fill_value=None,
                                      name='pix2wavelength')

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = pixel_coords

        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(points=points,
                                              lookup_table=lookup_table,
                                              bounds_error=False,
                                              fill_value=None,
                                              name='wavelength2pix')

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

        # Sometimes the slit is perpendicular to the RA or Dec axis.
        # For example, if the slit is perpendicular to RA, that means
        # the slope of pix_to_xtan will be nearly zero, so make sure
        # mapping.inverse uses pix_to_ytan.inverse.  The auto definition
        # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.
        if np.isclose(pix_to_ytan.slope, 0, atol=1e-8):
            mapping_tuple = (0, 1)
            # Account for vertical or horizontal dispersion on detector
            if spatial_axis:
                mapping.inverse = Mapping(mapping_tuple[::-1])
            else:
                mapping.inverse = Mapping(mapping_tuple)

        # The final transform
        # redefine the ra, dec center tangent point to include all data

        # check if all_ra crosses 0 degrees - this makes it hard to
        # define the min and max ra correctly
        all_ra = wrap_ra(all_ra)
        ra_min = np.amin(all_ra)
        ra_max = np.amax(all_ra)

        ra_center_final = (ra_max + ra_min) / 2.0
        dec_min = np.amin(all_dec)
        dec_max = np.amax(all_dec)
        dec_center_final = (dec_max + dec_min) / 2.0
        if len(self.input_models) == 1:
            # single model: use ra_center_pt to be consistent
            # with how resample was done before
            ra_center_final = ra_center_pt
            dec_center_final = dec_center_pt

        # The output frame type is taken from the last model processed above
        is_sky = resample_utils.is_sky_like(model.meta.wcs.output_frame)

        if is_sky:
            tan = Pix2Sky_TAN()
            native2celestial = RotateNative2Celestial(ra_center_final,
                                                      dec_center_final, 180)
            undist2sky = tan | native2celestial
            # find the spatial size of the output - same in x,y
            x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
        else:
            x_tan_all, _ = all_ra, all_dec
        x_min = np.amin(x_tan_all)
        x_max = np.amax(x_tan_all)
        x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_xtan.slope)))

        # single model use size of x_tan_array
        # to be consistent with method before
        if len(self.input_models) == 1:
            x_size = len(x_tan_array)

        # define the output wcs
        if is_sky:
            transform = mapping | (pix_to_xtan & pix_to_ytan
                                   | undist2sky) & pix_to_wavelength
        else:
            transform = mapping | (pix_to_xtan
                                   & pix_to_ytan) & pix_to_wavelength

        det = cf.Frame2D(name='detector', axes_order=(0, 1))
        if is_sky:
            sky = cf.CelestialFrame(name='sky',
                                    axes_order=(0, 1),
                                    reference_frame=coord.ICRS())
        else:
            sky = cf.Frame2D(
                name=f'resampled_{model.meta.wcs.output_frame.name}',
                axes_order=(0, 1))
        spec = cf.SpectralFrame(name='spectral',
                                axes_order=(2, ),
                                unit=(u.micron, ),
                                axes_names=('wavelength', ))
        world = cf.CompositeFrame([sky, spec], name='world')

        pipeline = [(det, transform), (world, None)]

        output_wcs = WCS(pipeline)

        # compute the output array size in WCS axes order, i.e. (x, y)
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = int(
            np.ceil(len(wavelength_array) / self.pscale_ratio))
        output_array_size[spatial_axis] = int(
            np.ceil(x_size / self.pscale_ratio))
        # turn the size into a numpy shape in (y, x) order
        self.data_size = tuple(output_array_size[::-1])
        bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
        output_wcs.bounding_box = bounding_box

        return output_wcs
예제 #19
0
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used (see TODOs).  The resulting WCS
        maps detector (x, y) to (ra, dec, lamda):

        * the pixel ``y`` drives the spatial (slit) direction: it is shifted
          by the slit center pixel, rotated by ``roll_ref`` plus the measured
          slit rotation angle, scaled by the measured spatial scale, then
          projected to the sky with a TAN projection centered on the slit
          center;
        * the pixel ``x`` drives the spectral direction: it is scaled by the
          spectral scale and offset by the minimum wavelength found in
          ``refwcs``.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS whose footprint is resampled.  Defaults to the WCS
            of the first input model.

        Notes
        -----
        Returns nothing; sets ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        det = x, y = wcstools.grid_from_domain(refwcs.domain)
        sky = ra, dec, lam = refwcs(*det)
        domain_xsize = refwcs.domain[0]['upper'] - refwcs.domain[0]['lower']
        domain_ysize = refwcs.domain[1]['upper'] - refwcs.domain[1]['lower']
        # Grid indices (relative to the grid origin) of the domain center.
        x_center, y_center = int(domain_xsize / 2), int(domain_ysize / 2)

        # Compute slit angular size, slit center sky coords.
        # For every grid row, interpolate over a small window of columns
        # around x_center to find the x pixel where that row reaches the
        # wavelength seen at the domain center; rows whose center wavelength
        # is NaN contribute NaN.
        xpos = []
        sz = 3
        for row in lam:
            if np.isnan(row[x_center]):
                xpos.append(np.nan)
            else:
                f = interpolate.interp1d(row[x_center - sz + 1:x_center + sz],
                    x[y_center, x_center - sz + 1:x_center + sz])
                xpos.append(f(lam[y_center, x_center]))
        x_arg = np.array(xpos)[~np.isnan(lam[:, x_center])]
        y_arg = y[~np.isnan(lam[:, x_center]), x_center]
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        # Pixel numbers along the slit, counted in reverse row order.
        pix_num = np.flipud(np.arange(len(slit_ra)))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
            dec=interpol_dec(slit_center_pix), unit=u.deg)
        log.debug('Slit center: {0}'.format(slit_center_sky))
        # NOTE: the fiducial is computed eagerly (the .format() argument is
        # evaluated) even when debug logging is disabled.
        log.debug('Fiducial: {0}'.format(resample_utils.compute_spec_fiducial([refwcs])))
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales.
        # The slit length in pixels is the hypotenuse of its x drift (dx)
        # and its extent in rows (number of valid rows - 1).
        xposn = np.array(xpos)[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Wavelength step between adjacent columns at the domain center.
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale) & Scale(spatial_scale)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
            lon_pole)
        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # Feed pixel y to both spatial inputs and pixel x to the spectral
        # transform; the inverse maps (ra, dec, lamda) outputs back to (x, y).
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object by pushing the
        # reference sky grid through the new WCS's backward transform.
        domain_grid = wnew.backward_transform(*sky)
        domain = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(domain_grid[axis])
            axis_max = np.nanmax(domain_grid[axis]) + 1
            domain.append({'lower': axis_min, 'upper': axis_max,
                'includes_lower': True, 'includes_upper': False})
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
예제 #20
0
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used (see TODOs).  The resulting WCS
        maps pixel (x, y) to (ra, dec, lamda):

        * ``x`` drives the spatial (slit) direction: it is shifted by the slit
          center pixel, scaled by the measured spatial scale, rotated by
          ``roll_ref`` and projected to the sky with a TAN projection centered
          on the slit center;
        * ``y`` drives the spectral direction: it is scaled by the spectral
          scale and offset by the minimum wavelength found in ``refwcs``.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS whose footprint is resampled.  Defaults to the WCS
            of the first input model.

        Notes
        -----
        Returns nothing; sets ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs`` (whose
        ``bounding_box`` covers the reference footprint).
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        x, y = wcstools.grid_from_bounding_box(refwcs.bounding_box,
                                               step=(1, 1),
                                               center=True)
        ra, dec, lam = refwcs(x.flatten(), y.flatten())
        # TODO: once astropy.modeling._Tabular is fixed, take out the
        # flatten() and reshape() code above and below
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        # With no Rotation2D there is nothing to undo: fall back to the plain
        # forward transform so refwcs_minus_rot is always bound (it is used
        # below to locate the slit center).
        refwcs_minus_rot = refwcs.forward_transform
        if rotation:
            rot_slit = functools.reduce(lambda a, b: a | b, rotation)
            unrotate = rot_slit.inverse
            refwcs_minus_rot = refwcs.forward_transform | \
                unrotate & Identity(1)
            # Correct for this rotation in the wcs
            ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
            ra = ra.reshape(x.shape)
            dec = dec.reshape(x.shape)
            lam = lam.reshape(x.shape)

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from bounding_box
        dx0 = refwcs.bounding_box[0][0]
        dx1 = refwcs.bounding_box[0][1]
        dy0 = refwcs.bounding_box[1][0]
        dy1 = refwcs.bounding_box[1][1]
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                       dy0 + dispersion_center_pix)
        slit_center_sky = SkyCoord(ra=slit_center[0],
                                   dec=slit_center[1],
                                   unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Mean wavelength step along the central column, ignoring NaNs.
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s', slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra,
                                        slit_center_sky.dec, lon_pole)
        min_lam = np.nanmin(lam)
        # Duplicate the spatial coordinate into both sky inputs; keep lam.
        mapping = Mapping((0, 0, 1))

        transform = Shift(-slit_center_pix) & Identity(1) | \
            Scale(spatial_scale) & Scale(spectral_scale) | \
            Identity(1) & Shift(min_lam) | mapping | \
            (rot | tan | skyrot) & Identity(1)

        # Input/output names must be strings, not the pixel-grid arrays.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object by pushing
        # the reference sky grid through the new WCS's backward transform.
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)

        bounding_box = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(bounding_box_grid[axis])
            axis_max = np.nanmax(bounding_box_grid[axis])
            bounding_box.append((axis_min, axis_max))
        wnew.bounding_box = tuple(bounding_box)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
예제 #21
0
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used (see TODOs).  The resulting WCS
        maps pixel (x, y) to (ra, dec, lamda):

        * ``x`` drives the spatial (slit) direction: it is shifted by the slit
          center pixel, scaled by the measured spatial scale, rotated by
          ``roll_ref`` and projected to the sky with a TAN projection centered
          on the slit center;
        * ``y`` drives the spectral direction: it is scaled by the spectral
          scale and offset by the minimum wavelength found in ``refwcs``.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS whose footprint is resampled.  Defaults to the WCS
            of the first input model.

        Notes
        -----
        Returns nothing; sets ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs`` (whose
        ``domain`` covers the reference footprint, upper bound exclusive).
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        x, y = wcstools.grid_from_domain(refwcs.domain)
        ra, dec, lam = refwcs(x.flatten(), y.flatten())
        # TODO: once astropy.modeling._Tabular is fixed, take out the
        # flatten() and reshape() code above and below
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        # With no Rotation2D there is nothing to undo: fall back to the plain
        # forward transform so refwcs_minus_rot is always bound (it is used
        # below to locate the slit center).
        refwcs_minus_rot = refwcs.forward_transform
        if rotation:
            rot_slit = functools.reduce(lambda a, b: a | b, rotation)
            unrotate = rot_slit.inverse
            refwcs_minus_rot = refwcs.forward_transform | \
                unrotate & Identity(1)
            # Correct for this rotation in the wcs
            ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
            ra = ra.reshape(x.shape)
            dec = dec.reshape(x.shape)
            lam = lam.reshape(x.shape)

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from domain
        dx0 = refwcs.domain[0]['lower']
        dx1 = refwcs.domain[0]['upper']
        dy0 = refwcs.domain[1]['lower']
        dy1 = refwcs.domain[1]['upper']
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix, dy0 +
            dispersion_center_pix)
        slit_center_sky = SkyCoord(ra=slit_center[0], dec=slit_center[1],
            unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Mean wavelength step along the central column, ignoring NaNs.
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s', slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
            lon_pole)
        min_lam = np.nanmin(lam)
        # Duplicate the spatial coordinate into both sky inputs; keep lam.
        mapping = Mapping((0, 0, 1))

        transform = Shift(-slit_center_pix) & Identity(1) | \
            Scale(spatial_scale) & Scale(spectral_scale) | \
            Identity(1) & Shift(min_lam) | mapping | \
            (rot | tan | skyrot) & Identity(1)

        # Input/output names must be strings, not the pixel-grid arrays.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object by pushing the
        # reference sky grid through the new WCS's backward transform.
        domain_grid = wnew.backward_transform(ra, dec, lam)

        domain = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(domain_grid[axis])
            axis_max = np.nanmax(domain_grid[axis]) + 1
            domain.append({'lower': axis_min, 'upper': axis_max,
                'includes_lower': True, 'includes_upper': False})
        log.debug('Domain: {0} {1}'.format(domain[0]['lower'], domain[0]['upper']))
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew