def test_parameters_compound_models():
    """Smoke-test chaining a 2D rotation with a native-to-celestial rotation.

    The sky rotation is built from the ``ra``/``dec`` attributes of a
    ``SkyCoord`` and a longitude pole given as ``180 * u.deg``.  Nothing is
    asserted: the test passes as long as constructing the models and
    combining them with ``|`` does not raise.
    """
    # Instantiated only; it is not part of the chain built below.
    projection = Pix2Sky_TAN()

    pointing = coord.SkyCoord(ra=5.6, dec=-72, unit=u.deg)
    native_to_celestial = RotateNative2Celestial(
        pointing.ra, pointing.dec, 180 * u.deg
    )
    compound = Rotation2D(23) | native_to_celestial
def compute_spec_transform(fiducial, refwcs):
    """
    Compute a simple transform given a fiducial point in a spatial-spectral wcs.

    Parameters
    ----------
    fiducial : sequence
        Only ``fiducial[0][0]`` and ``fiducial[0][1]`` are read; they become
        the longitude and latitude of the native-to-celestial rotation.
    refwcs : callable
        Reference WCS.  Must expose ``wcsinfo`` (``cdelt1``, ``cdelt2``,
        ``cdelt3``, ``roll_ref``) and be callable as ``refwcs(x, y)``,
        returning three arrays of which the third is the wavelength.

    Returns
    -------
    transform : compound model
        ``Mapping((1, 1, 0)) | spatial & spectral`` with its outputs named
        ``('ra', 'dec', 'lamda')``.
    """
    wcsinfo = refwcs.wcsinfo
    # The two spatial increments are divided by 3600 before use
    # (presumably arcsec -> deg; confirm against the wcsinfo schema).
    spatial_step_1 = wcsinfo.cdelt1 / 3600.
    spatial_step_2 = wcsinfo.cdelt2 / 3600.
    spectral_step = wcsinfo.cdelt3
    roll = wcsinfo.roll_ref

    # Only the wavelengths are needed from the evaluated grid: the smallest
    # finite one becomes the zero point of the spectral axis.
    y, x = grid_from_spec_domain(refwcs)
    _, _, wavelengths = refwcs(x, y)
    lam_zero = np.nanmin(wavelengths)

    spatial = (
        (Shift(0.) & Shift(0.))
        | Rotation2D(roll)
        | (Scale(spatial_step_1) & Scale(spatial_step_2))
        | Pix2Sky_TAN()
        | RotateNative2Celestial(fiducial[0][0], fiducial[0][1], 180.0)
    )
    spectral = Scale(spectral_step) | Shift(lam_zero)

    # Input 1 feeds both spatial inputs, input 0 feeds the spectral branch.
    # The inverse is set by hand to pick outputs 2 and 1.
    axis_map = Mapping((1, 1, 0))
    axis_map.inverse = Mapping((2, 1))

    transform = axis_map | spatial & spectral
    transform.outputs = ('ra', 'dec', 'lamda')
    return transform
def _make_reference_gwcs_wcs(fits_hdr):
    """Build a detector -> v2v3 -> sky gWCS from a packaged FITS header file.

    Parameters
    ----------
    fits_hdr : str
        Header file name, resolved through ``get_pkg_data_filename``.

    Returns
    -------
    gw : gwcs.WCS
        WCS whose pipeline is ``detector -> v2v3 -> sky``.  ``crpix`` and
        ``crval`` are copied onto it from the FITS WCS, and its bounding box
        spans ``(-0.5, n - 0.5)`` along each entry of ``pixel_shape``.
    """
    header = fits.Header.fromfile(get_pkg_data_filename(fits_hdr))
    fw = fitswcs.WCS(header)

    # Arcsec <-> degree scalings, each applied to both axes.
    to_deg_1d = Scale(1.0 / 3600.0, name='arcsec_to_deg_1D')
    arcsec_to_deg = to_deg_1d & to_deg_1d
    arcsec_to_deg.name = 'arcsec_to_deg_2D'

    to_arcsec_1d = Scale(3600.0, name='deg_to_arcsec_1D')
    deg_to_arcsec = to_arcsec_1d & to_arcsec_1d
    deg_to_arcsec.name = 'deg_to_arcsec_2D'

    cart_to_sph = CartesianToSpherical(name='c2s', wrap_lon_at=180)
    sph_to_cart = SphericalToCartesian(name='s2c', wrap_lon_at=180)

    # (x, y, z) -> (x/x, y/x, z/x), then keep the last two components.
    divided_by_x = (
        Mapping((0, 1, 2), name='xyz')
        / Mapping((0, 0, 0), n_inputs=3, name='xxx')
    )
    cart_to_tan = divided_by_x | Mapping((1, 2), name='xtyt')
    cart_to_tan.name = 'Cartesian 3D to TAN'

    # (xt, yt) -> (xt, xt, yt) -> (1, xt, yt)
    tan_to_cart = (
        Mapping((0, 0, 1), n_inputs=2, name='xtyt2xyz')
        | (Const1D(1, name='one') & Identity(2, name='I(2D)'))
    )
    tan_to_cart.name = 'TAN to cartesian 3D'

    # The two conversions are declared to be each other's inverse.
    tan_to_cart.inverse = cart_to_tan
    cart_to_tan.inverse = tan_to_cart

    cd_matrix = AffineTransformation2D(matrix=fw.wcs.cd)
    shift_x = Shift(-fw.wcs.crpix[0])
    shift_y = Shift(-fw.wcs.crpix[1])

    pixel_size = 5e-6
    plate_scale = Scale(pixel_size) & Scale(pixel_size)

    detector_to_v2v3 = (
        (shift_x & shift_y) | plate_scale | tan_to_cart | cart_to_sph | deg_to_arcsec
    )

    sph_to_tan = sph_to_cart | cart_to_tan
    projection = Pix2Sky_TAN()
    native_to_celestial = RotateNative2Celestial(
        fw.wcs.crval[0], fw.wcs.crval[1], 180
    )
    v2v3_to_sky = (
        arcsec_to_deg
        | sph_to_tan
        | plate_scale.inverse
        | cd_matrix
        | projection
        | native_to_celestial
    )

    sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS())
    detector_frame = cf.Frame2D(name='detector')
    v2v3_frame = cf.Frame2D(
        name="v2v3",
        unit=(u.arcsec, u.arcsec),
        axes_names=('x', 'y'),
        axes_order=(0, 1),
    )

    steps = [
        (detector_frame, detector_to_v2v3),
        (v2v3_frame, v2v3_to_sky),
        (sky_frame, None),
    ]
    gw = gwcs.WCS(
        input_frame=detector_frame,
        output_frame=sky_frame,
        forward_transform=steps,
    )

    gw.crpix = fw.wcs.crpix
    gw.crval = fw.wcs.crval
    gw.bounding_box = (
        (-0.5, fw.pixel_shape[0] - 0.5),
        (-0.5, fw.pixel_shape[1] - 0.5),
    )
    return gw
def map_to_transform(smap):
    """Build the pixel-to-sky compound model described by a map object.

    Parameters
    ----------
    smap : object
        Must expose ``reference_pixel``, ``reference_coordinate`` (with
        ``Tx`` and ``Ty``), ``scale`` (two values) and ``rotation_matrix``.

    Returns
    -------
    compound model
        ``shift | scale | affine rotation | TAN projection | sky rotation``.
    """
    # Shift the reference pixel from FITS (1) to Python (0) indexing.
    ref_pix_x, ref_pix_y = u.Quantity(smap.reference_pixel) - 1 * u.pixel

    ref_coord = smap.reference_coordinate
    ref_lon, ref_lat = ref_coord.Tx, ref_coord.Ty

    step_x, step_y = smap.scale
    pc_matrix = smap.rotation_matrix * u.arcsec

    # FITS WCS uses the negative of the reference pixel as the shift.
    to_reference_pixel = Shift(-ref_pix_x) & Shift(-ref_pix_y)
    pixel_scale = Multiply(step_x) & Multiply(step_y)
    rotation = AffineTransformation2D(pc_matrix, translation=(0, 0) * u.arcsec)
    # Handle the projection.
    projection = Pix2Sky_TAN()
    # Rotate from native spherical to celestial spherical coordinates.
    to_celestial = RotateNative2Celestial(ref_lon, ref_lat, 180 * u.deg)

    # Combine the whole pipeline into one compound model.
    pipeline = to_reference_pixel | pixel_scale | rotation | projection | to_celestial
    # NOTE(review): the return value of rename() is discarded; confirm that
    # it renames the instance in place for the astropy version in use.
    pipeline.rename("spatial")
    return pipeline
def spatial_model_from_quantity(crpix1, crpix2, cdelt1, cdelt2, pc, crval1, crval2,
                                projection='TAN'):
    """
    Given quantity representations of a HPLx FITS WCS return a model for the
    spatial transform.

    The ordering of ctype1 and ctype2 should be LON, LAT

    Parameters
    ----------
    crpix1, crpix2
        Reference pixel; the negated values are used as shifts.
    cdelt1, cdelt2
        Per-axis multipliers applied after the shift.
    pc
        Matrix passed to ``AffineTransformation2D``.
    crval1, crval2
        Longitude and latitude given to ``RotateNative2Celestial``.
    projection : str, optional
        Key into the table of supported projections; only ``'TAN'`` is
        present.  Any other key raises ``KeyError``.

    Returns
    -------
    compound model
        ``shift | scale | affine | projection | sky rotation``.
    """
    # TODO: Find this from somewhere else or extend it or something
    known_projections = {'TAN': Pix2Sky_TAN()}

    pixel_shift = Shift(-crpix1) & Shift(-crpix2)
    pixel_scale = Multiply(cdelt1) & Multiply(cdelt2)
    rotation = AffineTransformation2D(pc, translation=(0, 0) * u.arcsec)
    native_projection = known_projections[projection]
    to_celestial = RotateNative2Celestial(crval1, crval2, 180 * u.deg)

    return pixel_shift | pixel_scale | rotation | native_projection | to_celestial
def build_miri_output_wcs(self, refwcs=None):
    """
    Create a simple output wcs covering footprint of the input datamodels

    Parameters
    ----------
    refwcs : WCS object, optional
        Reference WCS.  When ``None``, ``meta.wcs`` of the first entry of
        ``self.input_models`` is used.

    Returns
    -------
    None
        Results are stored on the instance: ``output_spatial_scale``,
        ``output_spectral_scale`` and ``output_wcs``.
    """
    # TODO: generalize this for more than one input datamodel
    # TODO: generalize this for imaging modes with distorted wcs
    input_model = self.input_models[0]
    if refwcs is None:
        refwcs = input_model.meta.wcs

    x, y = wcstools.grid_from_bounding_box(refwcs.bounding_box,
                                           step=(1, 1), center=True)
    ra, dec, lam = refwcs(x.flatten(), y.flatten())
    # TODO: once astropy.modeling._Tabular is fixed, take out the
    # flatten() and reshape() code above and below
    ra = ra.reshape(x.shape)
    dec = dec.reshape(x.shape)
    lam = lam.reshape(x.shape)

    # Find rotation of the slit from y axis from the wcs forward transform
    # TODO: figure out if angle is necessary for MIRI.  See for discussion
    # https://github.com/STScI-JWST/jwst/pull/347
    rotation = [m for m in refwcs.forward_transform
                if isinstance(m, Rotation2D)]
    if rotation:
        rot_slit = functools.reduce(lambda left, right: left | right, rotation)
        unrotate = rot_slit.inverse
        refwcs_minus_rot = refwcs.forward_transform | unrotate & Identity(1)
        # Correct for this rotation in the wcs
        ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)
    else:
        # No rotation to undo: the slit centre below is evaluated with the
        # plain forward transform.  Without this branch the name would be
        # unbound and the slit-centre computation would raise NameError.
        refwcs_minus_rot = refwcs.forward_transform

    # Get the slit size at the center of the dispersion
    sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
    slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
    slit_angular_size = slit_coords[0].separation(slit_coords[-1])
    log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

    # Compute slit center from bounding_box
    dx0 = refwcs.bounding_box[0][0]
    dx1 = refwcs.bounding_box[0][1]
    dy0 = refwcs.bounding_box[1][0]
    dy1 = refwcs.bounding_box[1][1]
    slit_center_pix = (dx1 - dx0) / 2
    dispersion_center_pix = (dy1 - dy0) / 2
    slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                   dy0 + dispersion_center_pix)
    slit_center_sky = SkyCoord(ra=slit_center[0], dec=slit_center[1],
                               unit=u.deg)
    log.debug('slit center: {0}'.format(slit_center))

    # Compute spatial and spectral scales
    spatial_scale = slit_angular_size / slit_coords.shape[0]
    log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
    tcenter = int((dx1 - dx0) / 2)
    trace = lam[:, tcenter]
    trace = trace[~np.isnan(trace)]
    spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
    log.debug('spectral scale: {0}'.format(spectral_scale))

    # Compute transform for output frame
    log.debug('Slit center %s' % slit_center_pix)
    # TODO: double-check the signs on the following rotation angles
    roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
    rot = Rotation2D(roll_ref)
    tan = Pix2Sky_TAN()
    lon_pole = _compute_lon_pole(slit_center_sky, tan)
    skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
                                    lon_pole)
    min_lam = np.nanmin(lam)
    # Input 0 feeds both inputs of the spatial chain, input 1 passes
    # through as the spectral coordinate.
    mapping = Mapping((0, 0, 1))

    transform = Shift(-slit_center_pix) & Identity(1) | \
        Scale(spatial_scale) & Scale(spectral_scale) | \
        Identity(1) & Shift(min_lam) | mapping | \
        (rot | tan | skyrot) & Identity(1)

    # Inputs are axis *names*; the pixel-grid arrays x and y must not be
    # assigned here.
    transform.inputs = ('x', 'y')
    transform.outputs = ('ra', 'dec', 'lamda')

    # Build the output wcs
    input_frame = refwcs.input_frame
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the bounding_box in the output frame wcs object
    bounding_box_grid = wnew.backward_transform(ra, dec, lam)
    wnew.bounding_box = tuple(
        (np.nanmin(bounding_box_grid[axis]), np.nanmax(bounding_box_grid[axis]))
        for axis in input_frame.axes_order
    )

    # Update class properties
    self.output_spatial_scale = spatial_scale
    self.output_spectral_scale = spectral_scale
    self.output_wcs = wnew
def build_nirspec_output_wcs(self, refwcs=None):
    """
    Create a simple output wcs covering footprint of the input datamodels

    Parameters
    ----------
    refwcs : WCS object, optional
        Reference WCS.  When ``None``, ``meta.wcs`` of the first entry of
        ``self.input_models`` is used.

    Returns
    -------
    None
        Results are stored on the instance: ``output_spatial_scale``,
        ``output_spectral_scale`` and ``output_wcs``.
    """
    # TODO: generalize this for more than one input datamodel
    # TODO: generalize this for imaging modes with distorted wcs
    input_model = self.input_models[0]
    if refwcs is None:
        refwcs = input_model.meta.wcs

    # Generate grid of sky coordinates for area within bounding box
    bb = refwcs.bounding_box
    x, y = wcstools.grid_from_bounding_box(bb, step=(1, 1), center=True)
    ra, dec, lam = refwcs(x, y)
    x_center = int((bb[0][1] - bb[0][0]) / 2)
    y_center = int((bb[1][1] - bb[1][0]) / 2)
    log.debug("Center of bounding box: {} {}".format(x_center, y_center))

    # Compute slit angular size, slit center sky coords.
    # For every row, interpolate (over a small window of columns around the
    # centre) the x position at which that row has the same wavelength as
    # the central pixel.  Rows with no wavelength at the centre column are
    # recorded as NaN.
    sz = 3
    window = slice(x_center - sz + 1, x_center + sz)
    x_window = x[y_center, window]
    lam_ref = lam[y_center, x_center]
    xpos = []
    for row in lam:
        if np.isnan(row[x_center]):
            xpos.append(np.nan)
        else:
            f = interpolate.interp1d(row[window], x_window,
                                     bounds_error=False,
                                     fill_value='extrapolate')
            xpos.append(f(lam_ref))
    xpos = np.array(xpos)

    has_lam = ~np.isnan(lam[:, x_center])
    x_arg = xpos[has_lam]
    y_arg = y[has_lam, x_center]
    slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
    slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
    # Slit pixels are numbered in reverse order of the rows.
    pix_num = np.flipud(np.arange(len(slit_ra)))
    interpol_ra = interpolate.interp1d(pix_num, slit_ra)
    interpol_dec = interpolate.interp1d(pix_num, slit_dec)
    slit_center_pix = len(slit_spec_ref) / 2. - 1
    log.debug('Slit center pix: {0}'.format(slit_center_pix))
    slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                               dec=interpol_dec(slit_center_pix), unit=u.deg)
    log.debug('Slit center: {0}'.format(slit_center_sky))
    log.debug('Fiducial: {0}'.format(
        resample_utils.compute_spec_fiducial([refwcs])))
    angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
    log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
    dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
    offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
    log.debug('Offset up the slit: {0}'.format(offset_up_slit))

    # Compute spatial and spectral scales
    xposn = xpos[~np.isnan(xpos)]
    dx = xposn[-1] - xposn[0]
    slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
    spatial_scale = angular_slit_size / slit_npix
    log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
    spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

    # Compute slit angle relative (clockwise) to y axis
    slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
    slit_rot_angle = slit_rot_angle.value
    log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

    # Compute transform for output frame
    roll_ref = input_model.meta.wcsinfo.roll_ref
    min_lam = np.nanmin(lam)
    offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)
    # TODO: double-check the signs on the following rotation angles
    rot = Rotation2D(roll_ref + slit_rot_angle)
    scale = Scale(spatial_scale.value) & Scale(spatial_scale.value)
    tan = Pix2Sky_TAN()
    lon_pole = _compute_lon_pole(slit_center_sky, tan)
    skyrot = RotateNative2Celestial(slit_center_sky.ra.value,
                                    slit_center_sky.dec.value,
                                    lon_pole.value)
    spatial_trans = offset | rot | scale | tan | skyrot
    spectral_trans = Scale(spectral_scale) | Shift(min_lam)
    # Input 1 feeds both spatial inputs, input 0 feeds the spectral branch.
    # The inverse is set by hand to pick outputs 2 and 1.
    mapping = Mapping((1, 1, 0))
    mapping.inverse = Mapping((2, 1))
    transform = mapping | spatial_trans & spectral_trans
    transform.outputs = ('ra', 'dec', 'lamda')

    # Build the output wcs
    input_frame = refwcs.input_frame
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the bounding_box in the output frame wcs object
    bounding_box_grid = wnew.backward_transform(ra, dec, lam)
    wnew.bounding_box = tuple(
        (np.nanmin(bounding_box_grid[axis]), np.nanmax(bounding_box_grid[axis]))
        for axis in input_frame.axes_order
    )

    # Update class properties
    self.output_spatial_scale = spatial_scale
    self.output_spectral_scale = spectral_scale
    self.output_wcs = wnew
def build_interpolated_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS output frame using all the input models

    Creates output frame by linearly fitting RA, Dec along the slit and
    producing a lookup table to interpolate wavelengths in the dispersion
    direction.

    Parameters
    ----------
    refmodel : `~jwst.datamodels.DataModel`, optional
        Not read by this implementation; the first entry of
        ``self.input_models`` always sets the spatial sampling.  Kept in
        the signature for interface compatibility.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS.  As a side effect
        ``self.data_size`` is set to the output array shape in (y, x) order.
    """
    # for each input model convert slit x,y to ra,dec,lam
    # use first input model to set spatial scale
    # use center of appended ra and dec arrays to set up
    # center of final ra,dec
    # append all ra,dec, wavelength array for each slit
    # use first model to initialize wavelength array
    # append wavelengths that fall outside the endpoint of
    # of wavelength array when looping over additional data
    all_wavelength = []
    all_ra_slit = []
    all_dec_slit = []

    for im, model in enumerate(self.input_models):
        wcs = model.meta.wcs
        bb = wcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bb)
        ra, dec, lam = np.array(wcs(*grid))

        spectral_axis = find_dispersion_axis(model)
        spatial_axis = spectral_axis ^ 1

        # Compute the wavelength array, trimming NaNs from the ends
        # In many cases, a whole slice is NaNs, so ignore those warnings.
        # catch_warnings restores the caller's filters on exit instead of
        # wiping them process-wide.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            wavelength_array = np.nanmedian(lam, axis=spectral_axis)
        wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

        # We need to estimate the spatial sampling to use for the output WCS.
        # It is assumed the spatial sampling is the same for all the input
        # models. So we can use the first input model to set the spatial
        # sampling.
        # Steps to do this for first input model:
        # 1. find the middle of the spectrum in wavelength
        # 2. Pull out the ra and dec at the center of the slit.
        # 3. Find the mean ra,dec and the center of the slit this will
        #    represent the tangent point
        # 4. Convert ra,dec -> tangent plane projection: x_tan,y_tan
        # 5. using x_tan, y_tan perform a linear fit to find spatial sampling
        # first input model initializes wavelength array and defines
        # the spatial scale of the output wcs
        if im == 0:
            all_wavelength.extend(wavelength_array)

            lam_center_index = int(
                (bb[spectral_axis][1] - bb[spectral_axis][0]) / 2)
            if spatial_axis == 0:
                ra_center = ra[lam_center_index, :]
                dec_center = dec[lam_center_index, :]
            else:
                ra_center = ra[:, lam_center_index]
                dec_center = dec[:, lam_center_index]
            # find the ra and dec for this slit using center of slit
            ra_center_pt = np.nanmean(ra_center)
            dec_center_pt = np.nanmean(dec_center)

            if resample_utils.is_sky_like(model.meta.wcs.output_frame):
                # convert ra and dec to tangent projection
                tan = Pix2Sky_TAN()
                native2celestial = RotateNative2Celestial(
                    ra_center_pt, dec_center_pt, 180)
                undist2sky1 = tan | native2celestial
                # Filter out RuntimeWarnings due to computed NaNs in the WCS
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    # at this center of slit find x,y tangent projection
                    x_tan, y_tan = undist2sky1.inverse(ra, dec)
            else:
                # for non sky-like output frames, no need to do tangent
                # plane projections but we still use the same variables
                x_tan, y_tan = ra, dec

            # pull out data from center
            if spectral_axis == 0:
                x_tan_array = x_tan.T[lam_center_index]
                y_tan_array = y_tan.T[lam_center_index]
            else:
                x_tan_array = x_tan[lam_center_index]
                y_tan_array = y_tan[lam_center_index]

            x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
            y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

            # estimate the spatial sampling
            fitter = LinearLSQFitter()
            fit_model = Linear1D()

            xstop = x_tan_array.shape[0] / self.pscale_ratio
            xstep = 1 / self.pscale_ratio
            ystop = y_tan_array.shape[0] / self.pscale_ratio
            ystep = 1 / self.pscale_ratio
            pix_to_xtan = fitter(fit_model, np.arange(0, xstop, xstep),
                                 x_tan_array)
            pix_to_ytan = fitter(fit_model, np.arange(0, ystop, ystep),
                                 y_tan_array)

        # append all ra and dec values to use later to find min and max
        # ra and dec
        ra_use = ra.flatten()
        ra_use = ra_use[~np.isnan(ra_use)]
        dec_use = dec.flatten()
        dec_use = dec_use[~np.isnan(dec_use)]
        all_ra_slit.append(ra_use)
        all_dec_slit.append(dec_use)

        # now check wavelength array to see if we need to add to it
        this_minw = np.min(wavelength_array)
        this_maxw = np.max(wavelength_array)
        all_minw = np.min(all_wavelength)
        all_maxw = np.max(all_wavelength)

        if this_minw < all_minw:
            all_wavelength.extend(wavelength_array[wavelength_array < all_minw])
        if this_maxw > all_maxw:
            all_wavelength.extend(wavelength_array[wavelength_array > all_maxw])

    # done looping over set of models
    all_ra = np.hstack(all_ra_slit)
    all_dec = np.hstack(all_dec_slit)
    all_wave = np.hstack(all_wavelength)
    all_wave = all_wave[~np.isnan(all_wave)]
    all_wave = np.sort(all_wave, axis=None)
    # Tabular interpolation model, pixels -> lambda
    wavelength_array = np.unique(all_wave)

    # Check if the data is MIRI LRS FIXED Slit. If it is then
    # the wavelength array needs to be flipped so that the resampled
    # dispersion direction matches the dispersion direction on the detector.
    if self.input_models[0].meta.exposure.type == 'MIR_LRS-FIXEDSLIT':
        wavelength_array = np.flip(wavelength_array, axis=None)

    step = 1 / self.pscale_ratio
    stop = wavelength_array.shape[0] / self.pscale_ratio
    points = np.arange(0, stop, step)
    pix_to_wavelength = Tabular1D(points=points,
                                  lookup_table=wavelength_array,
                                  bounds_error=False, fill_value=None,
                                  name='pix2wavelength')

    # Tabular models need an inverse explicitly defined.
    # If the wavelength array is descending instead of ascending, both
    # points and lookup_table need to be reversed in the inverse transform
    # for scipy.interpolate to work properly
    points = wavelength_array
    lookup_table = np.arange(0, stop, step)

    if not np.all(np.diff(wavelength_array) > 0):
        points = points[::-1]
        lookup_table = lookup_table[::-1]
    pix_to_wavelength.inverse = Tabular1D(points=points,
                                          lookup_table=lookup_table,
                                          bounds_error=False, fill_value=None,
                                          name='wavelength2pix')

    # For the input mapping, duplicate the spatial coordinate
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

    # Sometimes the slit is perpendicular to the RA or Dec axis.
    # For example, if the slit is perpendicular to RA, that means
    # the slope of pix_to_xtan will be nearly zero, so make sure
    # mapping.inverse uses pix_to_ytan.inverse.  The auto definition
    # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.
    if np.isclose(pix_to_ytan.slope, 0, atol=1e-8):
        mapping_tuple = (0, 1)
        # Account for vertical or horizontal dispersion on detector
        if spatial_axis:
            mapping.inverse = Mapping(mapping_tuple[::-1])
        else:
            mapping.inverse = Mapping(mapping_tuple)

    # The final transform
    # redefine the ra, dec center tangent point to include all data

    # check if all_ra crosses 0 degrees - this makes it hard to
    # define the min and max ra correctly
    all_ra = wrap_ra(all_ra)
    ra_min = np.amin(all_ra)
    ra_max = np.amax(all_ra)
    ra_center_final = (ra_max + ra_min) / 2.0

    dec_min = np.amin(all_dec)
    dec_max = np.amax(all_dec)
    dec_center_final = (dec_max + dec_min) / 2.0
    tan = Pix2Sky_TAN()
    single_model = len(self.input_models) == 1
    if single_model:
        # single model use ra_center_pt to be consistent
        # with how resample was done before
        ra_center_final = ra_center_pt
        dec_center_final = dec_center_pt

    # ``model`` is the last model visited by the loop above; its output
    # frame decides whether a tangent-plane projection is part of the WCS.
    sky_like = resample_utils.is_sky_like(model.meta.wcs.output_frame)

    if sky_like:
        native2celestial = RotateNative2Celestial(ra_center_final,
                                                  dec_center_final, 180)
        undist2sky = tan | native2celestial
        # find the spatial size of the output - same in x,y
        x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
    else:
        x_tan_all, _ = all_ra, all_dec
    x_min = np.amin(x_tan_all)
    x_max = np.amax(x_tan_all)
    x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_xtan.slope)))

    # single model use size of x_tan_array
    # to be consistent with method before
    if single_model:
        x_size = len(x_tan_array)

    # define the output wcs
    if sky_like:
        transform = mapping | (pix_to_xtan & pix_to_ytan
                               | undist2sky) & pix_to_wavelength
    else:
        transform = mapping | (pix_to_xtan & pix_to_ytan) & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    if sky_like:
        sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                                reference_frame=coord.ICRS())
    else:
        sky = cf.Frame2D(
            name=f'resampled_{model.meta.wcs.output_frame.name}',
            axes_order=(0, 1))
    spec = cf.SpectralFrame(name='spectral', axes_order=(2, ),
                            unit=(u.micron, ), axes_names=('wavelength', ))
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, transform), (world, None)]

    output_wcs = WCS(pipeline)

    # compute the output array size in WCS axes order, i.e. (x, y)
    output_array_size = [0, 0]
    output_array_size[spectral_axis] = int(
        np.ceil(len(wavelength_array) / self.pscale_ratio))
    output_array_size[spatial_axis] = int(
        np.ceil(x_size / self.pscale_ratio))

    # turn the size into a numpy shape in (y, x) order
    self.data_size = tuple(output_array_size[::-1])
    bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
    output_wcs.bounding_box = bounding_box

    return output_wcs
def _make_gwcs_wcs(fits_hdr):
    """Build a gWCS with SIP distortion from a packaged FITS header file.

    The pipeline is ``detector -> v2v3 -> sky``; the first step contains the
    SIP polynomial correction read from the ``A_i_j`` / ``B_i_j`` keywords.
    Before returning, the result is compared at 100 random pixel positions
    against ``all_pix2world`` of the FITS WCS built from the same header.

    Parameters
    ----------
    fits_hdr : str
        Header file name, resolved through ``get_pkg_data_filename``.

    Returns
    -------
    gw : gwcs.WCS
        The constructed WCS, with ``crpix``, ``crval`` and ``bounding_box``
        set from the FITS WCS.

    Raises
    ------
    AssertionError
        If the gWCS and the FITS WCS disagree by more than ``1e-11`` at any
        of the sampled positions.
    """
    hdr = fits.Header.fromfile(get_pkg_data_filename(fits_hdr))
    fw = fitswcs.WCS(hdr)

    def _sip_coefficients(prefix, order):
        """Return the ``{prefix}_i_j`` header values (i + j <= order) keyed as ``c{i}_{j}``."""
        coeffs = {}
        for i in range(order + 1):
            for j in range(order + 1 - i):
                key = '{}_{:d}_{:d}'.format(prefix, i, j)
                if key in hdr:
                    coeffs['c{:d}_{:d}'.format(i, j)] = hdr[key]
        return coeffs

    a_order = hdr['A_ORDER']
    b_order = hdr['B_ORDER']
    cx = _sip_coefficients('A', a_order)
    cy = _sip_coefficients('B', b_order)

    # Shift to the reference pixel, evaluate both polynomials on (x, y)
    # (each with an added unit linear term), then shift back.
    sip_distortion = ((Shift(-fw.wcs.crpix[0]) & Shift(-fw.wcs.crpix[1])) |
                      Mapping((0, 1, 0, 1)) |
                      (polynomial.Polynomial2D(a_order, **cx, c1_0=1) &
                       polynomial.Polynomial2D(b_order, **cy, c0_1=1)) |
                      (Shift(fw.wcs.crpix[0]) & Shift(fw.wcs.crpix[1])))

    unit_conv = Scale(1.0 / 3600.0, name='arcsec_to_deg_1D')
    unit_conv = unit_conv & unit_conv
    unit_conv.name = 'arcsec_to_deg_2D'

    unit_conv_inv = Scale(3600.0, name='deg_to_arcsec_1D')
    unit_conv_inv = unit_conv_inv & unit_conv_inv
    unit_conv_inv.name = 'deg_to_arcsec_2D'

    c2s = CartesianToSpherical(name='c2s', wrap_lon_at=180)
    s2c = SphericalToCartesian(name='s2c', wrap_lon_at=180)
    # (x, y, z) -> (x/x, y/x, z/x), then keep the last two components.
    c2tan = ((Mapping((0, 1, 2), name='xyz') /
              Mapping((0, 0, 0), n_inputs=3, name='xxx')) |
             Mapping((1, 2), name='xtyt'))
    c2tan.name = 'Cartesian 3D to TAN'

    # (xt, yt) -> (xt, xt, yt) -> (1, xt, yt)
    tan2c = (Mapping((0, 0, 1), n_inputs=2, name='xtyt2xyz') |
             (Const1D(1, name='one') & Identity(2, name='I(2D)')))
    tan2c.name = 'TAN to cartesian 3D'

    # The two conversions are declared to be each other's inverse.
    tan2c.inverse = c2tan
    c2tan.inverse = tan2c

    aff = AffineTransformation2D(matrix=fw.wcs.cd)

    offx = Shift(-fw.wcs.crpix[0])
    offy = Shift(-fw.wcs.crpix[1])

    s = 5e-6
    scale = Scale(s) & Scale(s)

    sip_distortion |= (offx & offy) | scale | tan2c | c2s | unit_conv_inv

    taninv = s2c | c2tan

    tan = Pix2Sky_TAN()
    n2c = RotateNative2Celestial(fw.wcs.crval[0], fw.wcs.crval[1], 180)
    wcslin = unit_conv | taninv | scale.inverse | aff | tan | n2c

    sky_frm = cf.CelestialFrame(reference_frame=coord.ICRS())
    det_frm = cf.Frame2D(name='detector')
    v2v3_frm = cf.Frame2D(name="v2v3", unit=(u.arcsec, u.arcsec),
                          axes_names=('x', 'y'), axes_order=(0, 1))
    pipeline = [(det_frm, sip_distortion), (v2v3_frm, wcslin), (sky_frm, None)]

    gw = gwcs.WCS(input_frame=det_frm, output_frame=sky_frm,
                  forward_transform=pipeline)
    gw.crpix = fw.wcs.crpix
    gw.crval = fw.wcs.crval
    gw.bounding_box = ((-0.5, fw.pixel_shape[0] - 0.5),
                       (-0.5, fw.pixel_shape[1] - 0.5))

    # sanity check: the gWCS must reproduce the FITS WCS (evaluated with
    # origin=1) at randomly drawn pixel positions.
    for _ in range(100):
        x = np.random.randint(1, fw.pixel_shape[0])
        y = np.random.randint(1, fw.pixel_shape[1])
        assert np.allclose(gw(x, y), fw.all_pix2world(x, y, 1),
                           rtol=0, atol=1e-11)

    return gw
def build_interpolated_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS output frame

    Creates output frame by linearly fitting RA, Dec along the slit and
    producing a lookup table to interpolate wavelengths in the dispersion
    direction.

    Parameters
    ----------
    refmodel : `~jwst.datamodels.DataModel`
        The reference input image from which the fiducial WCS is created.
        If not specified, the first image in self.input_models is used.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS.  As a side effect
        ``self.data_size`` is set to the output array shape in (y, x) order.
    """
    if refmodel is None:
        refmodel = self.input_models[0]
    refwcs = refmodel.meta.wcs
    bb = refwcs.bounding_box

    grid = wcstools.grid_from_bounding_box(bb)
    ra, dec, lam = np.array(refwcs(*grid))

    # The tangent point is the mean of the finite RA/Dec values.
    lon = np.nanmean(ra)
    lat = np.nanmean(dec)
    tan = Pix2Sky_TAN()
    native2celestial = RotateNative2Celestial(lon, lat, 180)
    undist2sky = tan | native2celestial
    # Filter out RuntimeWarnings due to computed NaNs in the WCS.
    # catch_warnings restores the caller's filters on exit instead of
    # wiping them process-wide.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        x_tan, y_tan = undist2sky.inverse(ra, dec)

    spectral_axis = find_dispersion_axis(refmodel)
    spatial_axis = spectral_axis ^ 1

    # Compute the wavelength array, trimming NaNs from the ends
    wavelength_array = np.nanmedian(lam, axis=spectral_axis)
    wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

    # Compute RA and Dec up the slit (spatial direction) at the center
    # of the dispersion.  Use spectral_axis to determine slicing dimension
    lam_center_index = int(
        (bb[spectral_axis][1] - bb[spectral_axis][0]) / 2)
    if not spectral_axis:
        x_tan_array = x_tan.T[lam_center_index]
        y_tan_array = y_tan.T[lam_center_index]
    else:
        x_tan_array = x_tan[lam_center_index]
        y_tan_array = y_tan[lam_center_index]

    x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
    y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

    # Linear fits of the tangent-plane coordinates against pixel index.
    fitter = LinearLSQFitter()
    fit_model = Linear1D()
    pix_to_ra = fitter(fit_model, np.arange(x_tan_array.shape[0]),
                       x_tan_array)
    pix_to_dec = fitter(fit_model, np.arange(y_tan_array.shape[0]),
                        y_tan_array)

    # Tabular interpolation model, pixels -> lambda
    pix_to_wavelength = Tabular1D(lookup_table=wavelength_array,
                                  bounds_error=False, fill_value=None,
                                  name='pix2wavelength')

    # Tabular models need an inverse explicitly defined.
    # If the wavelength array is descending instead of ascending, both
    # points and lookup_table need to be reversed in the inverse transform
    # for scipy.interpolate to work properly
    points = wavelength_array
    lookup_table = np.arange(wavelength_array.shape[0])
    if not np.all(np.diff(wavelength_array) > 0):
        points = points[::-1]
        lookup_table = lookup_table[::-1]
    pix_to_wavelength.inverse = Tabular1D(points=points,
                                          lookup_table=lookup_table,
                                          bounds_error=False, fill_value=None,
                                          name='wavelength2pix')

    # For the input mapping, duplicate the spatial coordinate
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

    # Sometimes the slit is perpendicular to the RA or Dec axis.
    # For example, if the slit is perpendicular to RA, that means
    # the slope of pix_to_ra will be nearly zero, so make sure
    # mapping.inverse uses pix_to_dec.inverse.  The auto definition
    # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.
    if np.isclose(pix_to_dec.slope, 0, atol=1e-8):
        mapping_tuple = (0, 1)
        # Account for vertical or horizontal dispersion on detector
        if spatial_axis:
            mapping.inverse = Mapping(mapping_tuple[::-1])
        else:
            mapping.inverse = Mapping(mapping_tuple)

    # The final transform
    transform = mapping | (pix_to_ra & pix_to_dec
                           | undist2sky) & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                            reference_frame=coord.ICRS())
    spec = cf.SpectralFrame(name='spectral', axes_order=(2, ),
                            unit=(u.micron, ), axes_names=('wavelength', ))
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, transform), (world, None)]

    output_wcs = WCS(pipeline)

    # compute the output array size in WCS axes order, i.e. (x, y)
    output_array_size = [0, 0]
    output_array_size[spectral_axis] = len(wavelength_array)
    output_array_size[spatial_axis] = len(x_tan_array)

    # turn the size into a numpy shape in (y, x) order
    self.data_size = tuple(output_array_size[::-1])
    bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
    output_wcs.bounding_box = bounding_box

    return output_wcs