def build_nirspec_lamp_output_wcs(self):
    """
    Create a spatial/spectral WCS output frame for NIRSpec lamp mode

    Creates output frame by linearly fitting x_msa, y_msa along the slit and
    producing a lookup table to interpolate wavelengths in the dispersion
    direction.  Only the first input model is used.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS
    """
    model = self.input_models[0]
    wcs = model.meta.wcs
    bbox = wcs.bounding_box
    grid = wcstools.grid_from_bounding_box(bbox)
    x_msa, y_msa, lam = np.array(wcs(*grid))

    # Handle vertical (MIRI) or horizontal (NIRSpec) dispersion.  The
    # following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
    spectral_axis = find_dispersion_axis(model)
    spatial_axis = spectral_axis ^ 1

    # Compute the wavelength array, trimming NaNs from the ends.
    # In many cases a whole slice is NaN, which makes nanmedian emit a
    # RuntimeWarning.  Suppress it locally: catch_warnings restores the
    # caller's filters on exit, whereas warnings.resetwarnings() would
    # discard every filter configured elsewhere in the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        wavelength_array = np.nanmedian(lam, axis=spectral_axis)
    wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

    # Take the x_msa, y_msa values along the slit at the central wavelength
    lam_center_index = int((bbox[spectral_axis][1] - bbox[spectral_axis][0]) / 2)
    x_msa_array = x_msa.T[lam_center_index]
    y_msa_array = y_msa.T[lam_center_index]
    x_msa_array = x_msa_array[~np.isnan(x_msa_array)]
    y_msa_array = y_msa_array[~np.isnan(y_msa_array)]

    # Estimate and fit the spatial sampling.
    # Output pixel coordinates are built as arange(n) / pscale_ratio instead
    # of arange(0, n / pscale_ratio, 1 / pscale_ratio): with a float step,
    # rounding can make arange return n + 1 values, which would no longer
    # match the n data values being fitted / tabulated.
    fitter = LinearLSQFitter()
    fit_model = Linear1D()
    x_pixels = np.arange(x_msa_array.shape[0]) / self.pscale_ratio
    y_pixels = np.arange(y_msa_array.shape[0]) / self.pscale_ratio
    pix_to_x_msa = fitter(fit_model, x_pixels, x_msa_array)
    pix_to_y_msa = fitter(fit_model, y_pixels, y_msa_array)

    # Tabular interpolation model, pixels -> wavelength
    wave_pixels = np.arange(wavelength_array.shape[0]) / self.pscale_ratio
    pix_to_wavelength = Tabular1D(
        points=wave_pixels,
        lookup_table=wavelength_array,
        bounds_error=False,
        fill_value=None,
        name='pix2wavelength',
    )

    # Tabular models need an inverse explicitly defined.
    # If the wavelength array is descending instead of ascending, both
    # points and lookup_table need to be reversed in the inverse transform
    # for scipy.interpolate to work properly
    inv_points = wavelength_array
    inv_lookup_table = wave_pixels
    if not np.all(np.diff(wavelength_array) > 0):
        inv_points = inv_points[::-1]
        inv_lookup_table = inv_lookup_table[::-1]
    pix_to_wavelength.inverse = Tabular1D(
        points=inv_points,
        lookup_table=inv_lookup_table,
        bounds_error=False,
        fill_value=None,
        name='wavelength2pix',
    )

    # For the input mapping, duplicate the spatial coordinate
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))
    # The inverse keeps (wavelength pixel, pixel from y_msa); this (x, y)
    # ordering corresponds to dispersion along x (spectral_axis == 0).
    mapping.inverse = Mapping((2, 1))

    # The final transform
    # define the output wcs
    transform = mapping | pix_to_x_msa & pix_to_y_msa & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    sky = cf.Frame2D(name=f'resampled_{model.meta.wcs.output_frame.name}', axes_order=(0, 1))
    spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                            unit=(u.micron,), axes_names=('wavelength',))
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, transform),
                (world, None)]

    output_wcs = WCS(pipeline)

    # Compute the output array size (in WCS x, y order) and bounding box
    output_array_size = [0, 0]
    output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array) / self.pscale_ratio))
    x_size = len(x_msa_array)
    output_array_size[spatial_axis] = int(np.ceil(x_size / self.pscale_ratio))
    # turn the size into a numpy shape in (y, x) order
    output_wcs.array_shape = output_array_size[::-1]
    output_wcs.pixel_shape = output_array_size
    bounding_box = resample_utils.wcs_bbox_from_shape(output_array_size[::-1])
    output_wcs.bounding_box = bounding_box

    return output_wcs
def build_nirspec_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS covering footprint of the input

    Parameters
    ----------
    refmodel : `~jwst.datamodels.DataModel`, optional
        Model whose WCS defines the reference slit frame.  If not given,
        the first model in ``self.input_models`` is used.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS.  As a side effect,
        ``self.data_size`` is set to the output (ny, nx) shape.

    Raises
    ------
    ValueError
        If no output wavelengths can be determined, or the output spatial
        pixel scale cannot be estimated.
    """
    all_wcs = [m.meta.wcs for m in self.input_models if m is not refmodel]
    if refmodel is not None:
        all_wcs.insert(0, refmodel.meta.wcs)
    else:
        refmodel = self.input_models[0]

    refwcs = refmodel.meta.wcs

    s2d = refwcs.get_transform('slit_frame', 'detector')
    d2s = refwcs.get_transform('detector', 'slit_frame')
    s2w = refwcs.get_transform('slit_frame', 'world')

    # Estimate position of the target without relying on meta.target:
    # use the flux-weighted mean slit y position and wavelength.
    bbox = refwcs.bounding_box
    grid = wcstools.grid_from_bounding_box(bbox)
    _, s, lam = np.array(d2s(*grid))
    sd = s * refmodel.data
    ld = lam * refmodel.data
    good_s = np.isfinite(sd)
    total = np.sum(refmodel.data[good_s])
    if np.any(good_s) and total != 0:
        wmean_s = np.sum(sd[good_s]) / total
        wmean_l = np.sum(ld[good_s]) / total
    else:
        # No usable flux: fall back to the slit center (the midpoint of the
        # slit limits) and the wavelength at the center of the bounding box.
        wmean_s = 0.5 * (refmodel.slit_ymax + refmodel.slit_ymin)
        wmean_l = d2s(*np.mean(bbox, axis=1))[2]

    targ_ra, targ_dec, _ = s2w(0, wmean_s, wmean_l)

    ref_lam = _find_nirspec_output_sampling_wavelengths(
        all_wcs,
        targ_ra, targ_dec
    )
    ref_lam = np.array(ref_lam)

    n_lam = ref_lam.size
    if not n_lam:
        raise ValueError("Not enough data to construct output WCS.")

    x_slit = np.zeros(n_lam)
    lam = 1e-6 * ref_lam

    # Find the spatial pixel scale:
    y_slit_min, y_slit_max = self._max_virtual_slit_extent(all_wcs, targ_ra, targ_dec)

    # Evaluate the slit edges on the detector at nsampl wavelengths spread
    # evenly over the output wavelength grid.
    nsampl = 50
    sample_lam = lam[(np.arange(nsampl) * n_lam) // nsampl]
    xy_min = s2d(nsampl * [0], nsampl * [y_slit_min], sample_lam)
    xy_max = s2d(nsampl * [0], nsampl * [y_slit_max], sample_lam)

    good = np.logical_and(np.isfinite(xy_min), np.isfinite(xy_max))
    if not np.any(good):
        raise ValueError("Error estimating output WCS pixel scale.")

    xy1 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymin), lam)
    xy2 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymax), lam)
    xylen = np.nanmax(np.linalg.norm(np.array(xy1) - np.array(xy2), axis=0)) + 1
    # slit-frame units per detector pixel
    pscale = (refmodel.slit_ymax - refmodel.slit_ymin) / xylen

    # compute image span along Y-axis (length of the slit in the detector plane)
    det_slit_span = np.nanmax(np.linalg.norm(np.subtract(xy_max, xy_min), axis=0))
    ny = int(np.ceil(det_slit_span * self.pscale_ratio + 0.5)) + 1

    # Padding (in output pixels) on each side so that the slit is centered
    # in the ny output rows.
    border = 0.5 * (ny - det_slit_span * self.pscale_ratio) - 0.5
    # Slit-frame units per output pixel.  ``border`` is in output pixels, so
    # the intercept offset must use this same scale: pixel ``border`` then
    # maps exactly onto the slit edge.
    out_pscale = pscale / self.pscale_ratio

    if xy_min[1][1] < xy_max[1][1]:
        y_slit_model = Linear1D(
            slope=out_pscale,
            intercept=y_slit_min - border * out_pscale
        )
    else:
        y_slit_model = Linear1D(
            slope=-out_pscale,
            intercept=y_slit_max + border * out_pscale
        )

    # extrapolate 1/2 pixel at the edges and make tabular model w/inverse:
    lam = lam.tolist()
    pixel_coord = list(range(n_lam))

    if len(pixel_coord) > 1:
        # left: step half a pixel back from the first sample
        slope = (lam[1] - lam[0]) / pixel_coord[1]
        lam.insert(0, -0.5 * slope + lam[0])
        pixel_coord.insert(0, -0.5)
        # right: step half a pixel forward from the last sample
        slope = (lam[-1] - lam[-2]) / (pixel_coord[-1] - pixel_coord[-2])
        lam.append(lam[-1] + 0.5 * slope)
        pixel_coord.append(pixel_coord[-1] + 0.5)

    else:
        lam = 3 * lam
        pixel_coord = [-0.5, 0, 0.5]

    wavelength_transform = Tabular1D(points=pixel_coord,
                                     lookup_table=lam,
                                     bounds_error=False, fill_value=np.nan)
    wavelength_transform.inverse = Tabular1D(points=lam,
                                             lookup_table=pixel_coord,
                                             bounds_error=False,
                                             fill_value=np.nan)
    self.data_size = (ny, len(ref_lam))

    # Construct the final transform
    mapping = Mapping((0, 1, 0))
    mapping.inverse = Mapping((2, 1))
    out_det2slit = mapping | Identity(1) & y_slit_model & wavelength_transform

    # Create coordinate frames
    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    slit_spatial = cf.Frame2D(name='slit_spatial', axes_order=(0, 1),
                              unit=("", ""), axes_names=('x_slit', 'y_slit'))
    spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                            unit=(u.micron,), axes_names=('wavelength',))
    slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame')
    sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                            reference_frame=coord.ICRS())
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, out_det2slit), (slit_frame, s2w), (world, None)]
    output_wcs = WCS(pipeline)

    # Compute bounding box and output array shape.  The extra row for the
    # half pixel at the top and bottom of the slit (pixel coordinates are
    # pixel centers) is already included in ny above.
    bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
    output_wcs.bounding_box = bounding_box
    output_wcs.array_shape = self.data_size

    return output_wcs
def build_interpolated_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS output frame using all the input models

    Creates output frame by linearly fitting RA, Dec along the slit and
    producing a lookup table to interpolate wavelengths in the dispersion
    direction.

    Parameters
    ----------
    refmodel : `~jwst.datamodels.DataModel`, optional
        NOTE: accepted for interface compatibility but not currently used;
        the first image in self.input_models always sets the spatial
        sampling and the tangent point (for a single input model).

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS
    """
    # for each input model convert slit x,y to ra,dec,lam
    # use first input model to set spatial scale
    # use center of appended ra and dec arrays to set up
    # center of final ra,dec
    # append all ra,dec, wavelength array for each slit
    # use first model to initialize wavelength array
    # append wavelengths that fall outside the endpoint of
    # of wavelength array when looping over additional data

    all_wavelength = []
    all_ra_slit = []
    all_dec_slit = []

    for im, model in enumerate(self.input_models):
        wcs = model.meta.wcs
        bbox = wcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bbox)
        ra, dec, lam = np.array(wcs(*grid))
        # Handle vertical (MIRI) or horizontal (NIRSpec) dispersion.  The
        # following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
        spectral_axis = find_dispersion_axis(model)
        spatial_axis = spectral_axis ^ 1

        # Compute the wavelength array, trimming NaNs from the ends.
        # In many cases a whole slice is NaN, which makes nanmedian emit a
        # RuntimeWarning.  Suppress it locally: catch_warnings restores the
        # caller's filters on exit, whereas warnings.resetwarnings() would
        # discard every filter configured elsewhere in the process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            wavelength_array = np.nanmedian(lam, axis=spectral_axis)
        wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

        # We need to estimate the spatial sampling to use for the output WCS.
        # It is assumed the spatial sampling is the same for all the input
        # models. So we can use the first input model to set the spatial
        # sampling.

        # Steps to do this for first input model:
        # 1. find the middle of the spectrum in wavelength
        # 2. Pull out the ra and dec at the center of the slit.
        # 3. Find the mean ra,dec and the center of the slit this will
        #    represent the tangent point
        # 4. Convert ra,dec -> tangent plane projection: x_tan,y_tan
        # 5. using x_tan, y_tan perform a linear fit to find spatial sampling
        # first input model sets initializes wavelength array and defines
        # the spatial scale of the output wcs
        if im == 0:
            all_wavelength = np.append(all_wavelength, wavelength_array)

            # find the center ra and dec for this slit at central wavelength
            lam_center_index = int((bbox[spectral_axis][1]
                                    - bbox[spectral_axis][0]) / 2)
            if spatial_axis == 0:
                # MIRI LRS, the WCS x axis is spatial
                ra_slice = ra[lam_center_index, :]
                dec_slice = dec[lam_center_index, :]
            else:
                ra_slice = ra[:, lam_center_index]
                dec_slice = dec[:, lam_center_index]
            # wrap RA if near zero
            ra_center_pt = np.nanmean(wrap_ra(ra_slice))
            dec_center_pt = np.nanmean(dec_slice)

            # convert ra and dec to tangent projection
            tan = Pix2Sky_TAN()
            native2celestial = RotateNative2Celestial(ra_center_pt, dec_center_pt, 180)
            undist2sky1 = tan | native2celestial
            # Filter out RuntimeWarnings due to computed NaNs in the WCS
            # (scoped, so the caller's warning filters are left intact)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                # at this center of slit find x,y tangent projection - x_tan, y_tan
                x_tan, y_tan = undist2sky1.inverse(ra, dec)

            # pull out data from center
            if spectral_axis == 0:
                # Dispersion along x (e.g. NIRSpec): the spatial slice is the
                # column at the central wavelength
                x_tan_array = x_tan.T[lam_center_index]
                y_tan_array = y_tan.T[lam_center_index]
            else:
                # Dispersion along y (MIRI LRS): the spatial slice is the row
                x_tan_array = x_tan[lam_center_index]
                y_tan_array = y_tan[lam_center_index]

            x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
            y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

            # estimate the spatial sampling.  Pixel coordinates are built as
            # arange(n) / pscale_ratio: a float-step arange can return n + 1
            # values due to rounding, mismatching the n fitted data values.
            fitter = LinearLSQFitter()
            fit_model = Linear1D()
            x_pixels = np.arange(x_tan_array.shape[0]) / self.pscale_ratio
            y_pixels = np.arange(y_tan_array.shape[0]) / self.pscale_ratio
            pix_to_xtan = fitter(fit_model, x_pixels, x_tan_array)
            pix_to_ytan = fitter(fit_model, y_pixels, y_tan_array)

        # append all ra and dec values to use later to find min and max
        # ra and dec
        ra_use = ra[~np.isnan(ra)].flatten()
        dec_use = dec[~np.isnan(dec)].flatten()
        all_ra_slit = np.append(all_ra_slit, ra_use)
        all_dec_slit = np.append(all_dec_slit, dec_use)

        # now check wavelength array to see if we need to add to it
        this_minw = np.min(wavelength_array)
        this_maxw = np.max(wavelength_array)
        all_minw = np.min(all_wavelength)
        all_maxw = np.max(all_wavelength)

        if this_minw < all_minw:
            addpts = wavelength_array[wavelength_array < all_minw]
            all_wavelength = np.append(all_wavelength, addpts)
        if this_maxw > all_maxw:
            addpts = wavelength_array[wavelength_array > all_maxw]
            all_wavelength = np.append(all_wavelength, addpts)

    # done looping over set of models

    all_ra = np.hstack(all_ra_slit)
    all_dec = np.hstack(all_dec_slit)
    all_wave = np.hstack(all_wavelength)
    all_wave = all_wave[~np.isnan(all_wave)]

    # Tabular interpolation model, pixels -> lambda
    # (np.unique returns the sorted unique values)
    wavelength_array = np.unique(all_wave)

    # Check if the data is MIRI LRS FIXED Slit. If it is then
    # the wavelength array needs to be flipped so that the resampled
    # dispersion direction matches the dispersion direction on the detector.
    if self.input_models[0].meta.exposure.type == 'MIR_LRS-FIXEDSLIT':
        wavelength_array = np.flip(wavelength_array, axis=None)

    wave_pixels = np.arange(wavelength_array.shape[0]) / self.pscale_ratio
    pix_to_wavelength = Tabular1D(
        points=wave_pixels,
        lookup_table=wavelength_array,
        bounds_error=False,
        fill_value=None,
        name='pix2wavelength',
    )

    # Tabular models need an inverse explicitly defined.
    # If the wavelength array is descending instead of ascending, both
    # points and lookup_table need to be reversed in the inverse transform
    # for scipy.interpolate to work properly
    inv_points = wavelength_array
    inv_lookup_table = wave_pixels
    if not np.all(np.diff(wavelength_array) > 0):
        inv_points = inv_points[::-1]
        inv_lookup_table = inv_lookup_table[::-1]
    pix_to_wavelength.inverse = Tabular1D(
        points=inv_points,
        lookup_table=inv_lookup_table,
        bounds_error=False,
        fill_value=None,
        name='wavelength2pix',
    )

    # For the input mapping, duplicate the spatial coordinate
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

    # Sometimes the slit is perpendicular to the RA or Dec axis.
    # For example, if the slit is perpendicular to RA, that means
    # the slope of pix_to_xtan will be nearly zero, so make sure
    # mapping.inverse uses pix_to_ytan.inverse.  The auto definition
    # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.
    swap_xy = np.isclose(pix_to_xtan.slope, 0, atol=1e-8)
    if swap_xy:
        # Account for vertical or horizontal dispersion on detector
        mapping.inverse = Mapping((2, 1) if spatial_axis else (1, 2))

    # The final transform
    # redefine the ra, dec center tangent point to include all data

    # check if all_ra crosses 0 degrees - this makes it hard to
    # define the min and max ra correctly
    all_ra = wrap_ra(all_ra)
    ra_min = np.amin(all_ra)
    ra_max = np.amax(all_ra)
    ra_center_final = (ra_max + ra_min) / 2.0

    dec_min = np.amin(all_dec)
    dec_max = np.amax(all_dec)
    dec_center_final = (dec_max + dec_min) / 2.0

    tan = Pix2Sky_TAN()
    if len(self.input_models) == 1:
        # single model use ra_center_pt to be consistent
        # with how resample was done before
        ra_center_final = ra_center_pt
        dec_center_final = dec_center_pt

    native2celestial = RotateNative2Celestial(ra_center_final, dec_center_final, 180)
    undist2sky = tan | native2celestial
    # find the spatial size of the output - same in x,y
    if swap_xy:
        _, x_tan_all = undist2sky.inverse(all_ra, all_dec)
        pix_to_tan_slope = pix_to_ytan.slope
    else:
        x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
        pix_to_tan_slope = pix_to_xtan.slope

    x_min = np.amin(x_tan_all)
    x_max = np.amax(x_tan_all)
    x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_tan_slope)))
    # center the spatial axis on the tangent point
    if swap_xy:
        pix_to_ytan.intercept = -0.5 * (x_size - 1) * pix_to_ytan.slope
    else:
        pix_to_xtan.intercept = -0.5 * (x_size - 1) * pix_to_xtan.slope

    # single model use size of x_tan_array
    # to be consistent with method before
    if len(self.input_models) == 1:
        x_size = len(x_tan_array)

    # define the output wcs
    transform = mapping | (pix_to_xtan & pix_to_ytan | undist2sky) & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                            reference_frame=coord.ICRS())
    spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                            unit=(u.micron,), axes_names=('wavelength',))
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, transform),
                (world, None)]

    output_wcs = WCS(pipeline)

    # compute the output array size in WCS axes order, i.e. (x, y)
    output_array_size = [0, 0]
    output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array) / self.pscale_ratio))
    output_array_size[spatial_axis] = int(np.ceil(x_size / self.pscale_ratio))
    # turn the size into a numpy shape in (y, x) order
    output_wcs.array_shape = output_array_size[::-1]
    output_wcs.pixel_shape = output_array_size
    bounding_box = resample_utils.wcs_bbox_from_shape(output_array_size[::-1])
    output_wcs.bounding_box = bounding_box

    return output_wcs