def wcs_from_spec_footprints(wcslist, refwcs=None, transform=None, domain=None):
    """
    Build an output WCS from a list of spatial/spectral WCS objects.

    Build-7 workaround.

    Parameters
    ----------
    wcslist : iterable of gwcs.WCS
        Input WCS objects. Every footprint is pushed through the backward
        transform, but only the first result is used to derive the pixel
        offsets and the domain.
    refwcs : gwcs.WCS, optional
        Supplies the input and output frames. Defaults to ``wcslist[0]``.
    transform : optional
        Accepted for interface compatibility; the value passed in is
        replaced by the result of ``compute_spec_transform``.
    domain : optional
        Forwarded to ``compute_spec_fiducial``.

    Returns
    -------
    gwcs.WCS
        New WCS with its ``domain`` attribute set.

    Raises
    ------
    ValueError
        If ``wcslist`` is not iterable.
    TypeError
        If an item of ``wcslist``, or ``refwcs``, is not a ``gwcs.WCS``.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")
    if not all(isinstance(item, WCS) for item in wcslist):
        raise TypeError("All items in 'wcslist' must have instance of gwcs.WCS")

    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs. For
    # now, we just do it for one, using the api for a list of wcs.

    # Fiducial point of the output frame, then the transform built around it.
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Map every input footprint back into the pixel space of the new WCS.
    footprints = [spec_footprint(item) for item in wcslist]
    pixel_footprints = [wnew.backward_transform(*fp) for fp in footprints]
    first_pixels = pixel_footprints[0]

    input_frame = refwcs.input_frame
    axes = input_frame.axes_order

    # Shift the pixel origin to the minimum of the first footprint.
    lower_edges = [np.nanmin(first_pixels[axis]) for axis in axes]
    transform = Shift(lower_edges[0]) & Shift(lower_edges[1]) | transform
    wnew = WCS(output_frame=output_frame,
               input_frame=input_frame,
               forward_transform=transform)

    # Half-open [min, max + 1) interval per input axis.
    wnew.domain = [
        {
            'lower': np.nanmin(first_pixels[axis]),
            'upper': np.nanmax(first_pixels[axis]) + 1,
            'includes_lower': True,
            'includes_upper': False,
        }
        for axis in axes
    ]
    return wnew
def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray]) -> float:
    """Compute a single pixel scale for ``wcs`` around ``fiducial``.

    The fiducial is inverted to a pixel position; that position and its
    neighbours one pixel away along each axis are mapped back to the sky,
    and the angular separations are combined as a geometric mean.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        Reference WCS object from which to compute a scaling factor.
    fiducial : tuple or `~numpy.ndarray`
        Fiducial of (RA, DEC) used as the reference point.

    Returns
    -------
    scale : float
        ``sqrt(xscale * yscale)``, computed from coordinates read as degrees.

    Raises
    ------
    ValueError
        If ``fiducial`` does not have exactly two elements.
    """
    if len(fiducial) != 2:
        raise ValueError(f'Input fiducial must contain only (RA, DEC); Instead recieved: {fiducial}')

    ref_pixel = np.array(wcs.invert(*fiducial))

    # Columns: reference pixel, +1 along the first axis, +1 along the second.
    probe_pixels = np.vstack(
        (ref_pixel, ref_pixel + (1, 0), ref_pixel + (0, 1))
    ).T
    probe_world = wcs(*probe_pixels)

    sky = SkyCoord(ra=probe_world[0], dec=probe_world[1], unit="deg")
    origin = sky[0]
    xscale = np.abs(origin.separation(sky[1]).value)
    yscale = np.abs(origin.separation(sky[2]).value)

    return np.sqrt(xscale * yscale)
def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray], disp_axis: int = None, pscale_ratio: float = None) -> float:
    """Compute a pixel scale for ``wcs`` around ``fiducial``.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        Reference WCS object from which to compute a scaling factor.
    fiducial : tuple or `~numpy.ndarray`
        Fiducial of (RA, DEC) or (RA, DEC, Wavelength) used as the
        reference point.
    disp_axis : int, optional
        Dispersion axis integer. Assumes the same convention as
        `wcsinfo.dispersion_direction`. Required when the output frame of
        ``wcs`` has a ``'SPECTRAL'`` axis.
    pscale_ratio : float, optional
        Ratio of input to output pixel scale; when given, both axis scales
        are multiplied by it.

    Returns
    -------
    scale : float
        Geometric mean of the two axis scales for a non-spectral WCS. For a
        spectral WCS, the second scale if ``disp_axis == 1``, otherwise the
        first.

    Raises
    ------
    ValueError
        If the WCS is spectral and ``disp_axis`` is None.
    """
    axes_type = wcs.output_frame.axes_type
    is_spectral = 'SPECTRAL' in axes_type
    if is_spectral and disp_axis is None:
        raise ValueError('If input WCS is spectral, a disp_axis must be given')

    ref_pixel = np.array(wcs.invert(*fiducial))

    # Unit step placed at the index of the first spatial output axis; rolling
    # it by one gives the step along the following axis.
    spatial_idx = np.where(np.array(axes_type) == 'SPATIAL')[0]
    step = np.zeros_like(ref_pixel)
    step[spatial_idx[0]] = 1

    probe_pixels = np.vstack(
        (ref_pixel, ref_pixel + step, ref_pixel + np.roll(step, 1))
    ).T
    probe_world = wcs(*probe_pixels)

    sky = SkyCoord(ra=probe_world[spatial_idx[0]],
                   dec=probe_world[spatial_idx[1]],
                   unit="deg")
    origin = sky[0]
    xscale = np.abs(origin.separation(sky[1]).value)
    yscale = np.abs(origin.separation(sky[2]).value)

    if pscale_ratio is not None:
        xscale = xscale * pscale_ratio
        yscale = yscale * pscale_ratio

    if not is_spectral:
        return np.sqrt(xscale * yscale)

    # Assuming scale doesn't change with wavelength
    # Assuming disp_axis is consistent with
    # DataModel.meta.wcsinfo.dispersion.direction
    if disp_axis == 1:
        return yscale
    return xscale
def gwcs_1d():
    """Return a one-axis WCS that maps a pixel frame to a spectral frame.

    The forward transform is ``Identity(1)``. The pixel frame is named
    "detector" and the output frame is a ``SpectralFrame`` in nanometres.
    """
    # Bare scalars are passed on purpose: the original wrote ("pixel"),
    # ("x") and (u.pix), which are parenthesised scalars rather than tuples.
    pixel_frame = cf.CoordinateFrame(
        name="detector",
        naxes=1,
        axes_order=(0,),
        axes_type="pixel",
        axes_names="x",
        unit=u.pix,
    )
    wavelength_frame = cf.SpectralFrame(
        name="spectral",
        axes_order=(2,),
        unit=u.nm,
    )
    return WCS(
        forward_transform=Identity(1),
        input_frame=pixel_frame,
        output_frame=wavelength_frame,
    )
def assign_moving_target_wcs(input_model):
    """Append a 'moving_target' frame to the WCS of every model.

    The moving-target RA/Dec of all exposures are averaged. Each model's WCS
    pipeline is then extended with a shift from its own moving-target
    position to that average, ending in a copy of the output frame named
    'moving_target'.

    Parameters
    ----------
    input_model : datamodels.ModelContainer
        Container of exposures; every model is modified in place.

    Returns
    -------
    datamodels.ModelContainer
        The same container. ``meta.cal_step.assign_mtwcs`` is 'COMPLETE' on
        every model, or 'SKIPPED' when a moving-target coordinate is missing.

    Raises
    ------
    ValueError
        If ``input_model`` is not a ModelContainer, or a model's output
        frame is neither a CelestialFrame nor a CompositeFrame.
    """
    if not isinstance(input_model, datamodels.ModelContainer):
        raise ValueError("Expected a ModelContainer object")

    # Moving-target positions of all input exposures.
    all_ra = np.array(
        [member.meta.wcsinfo.mt_ra for member in input_model._models])
    all_dec = np.array(
        [member.meta.wcsinfo.mt_dec for member in input_model._models])

    # The mean cannot be formed if any exposure lacks a position.
    if (None in all_ra) or (None in all_dec):
        log.warning("One or more MT RA/Dec values missing in input images")
        log.warning("Step will be skipped, resulting in target misalignment")
        for member in input_model:
            member.meta.cal_step.assign_mtwcs = 'SKIPPED'
        return input_model

    mt_avra = all_ra.mean()
    mt_avdec = all_dec.mean()

    for member in input_model:
        steps = member.meta.wcs._pipeline[:-1]

        target_frame = deepcopy(member.meta.wcs.output_frame)
        target_frame.name = 'moving_target'

        own_ra = member.meta.wcsinfo.mt_ra
        own_dec = member.meta.wcsinfo.mt_dec
        member.meta.wcsinfo.mt_avra = mt_avra
        member.meta.wcsinfo.mt_avdec = mt_avdec

        rdel = mt_avra - own_ra
        ddel = mt_avdec - own_dec

        # A composite frame carries one extra axis that passes through.
        if isinstance(target_frame, cf.CelestialFrame):
            to_target = Shift(rdel) & Shift(ddel)
        elif isinstance(target_frame, cf.CompositeFrame):
            to_target = Shift(rdel) & Shift(ddel) & Identity(1)
        else:
            raise ValueError("Unrecognized coordinate frame.")

        steps.append((member.meta.wcs.output_frame, to_target))
        steps.append((target_frame, None))
        replacement = WCS(steps)

        del member.meta.wcs
        member.meta.wcs = replacement
        member.meta.cal_step.assign_mtwcs = 'COMPLETE'

    return input_model
def wcs_from_spec_footprints(wcslist, refwcs=None, transform=None, domain=None):
    """
    Create a WCS from a list of spatial/spectral WCS.

    Build-7 workaround.

    Parameters
    ----------
    wcslist : iterable of gwcs.WCS
        Input WCS objects; only the first one's footprint determines the
        pixel offsets and the resulting domain.
    refwcs : gwcs.WCS, optional
        Source of the input and output frames. Defaults to ``wcslist[0]``.
    transform : optional
        Not used; it is overwritten by ``compute_spec_transform``.
    domain : optional
        Passed through to ``compute_spec_fiducial``.

    Returns
    -------
    gwcs.WCS
        The new WCS, with ``domain`` assigned.

    Raises
    ------
    ValueError
        When ``wcslist`` is not iterable.
    TypeError
        When ``refwcs`` or any member of ``wcslist`` is not a ``gwcs.WCS``.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")

    for candidate in wcslist:
        if not isinstance(candidate, WCS):
            raise TypeError(
                "All items in 'wcslist' must have instance of gwcs.WCS")

    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs. For
    # now, we just do it for one, using the api for a list of wcs.

    # Fiducial point for the output frame, and the transform built from it.
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Run each input footprint through the backward transform of the
    # provisional output WCS to get it in output pixel coordinates.
    sky_footprints = [spec_footprint(candidate) for candidate in wcslist]
    pixel_grids = [wnew.backward_transform(*fp) for fp in sky_footprints]
    reference_grid = pixel_grids[0]

    input_frame = refwcs.input_frame

    # Per-axis extent of the reference footprint in pixel space.
    extents = []
    for axis in input_frame.axes_order:
        values = reference_grid[axis]
        extents.append((np.nanmin(values), np.nanmax(values) + 1))

    # Prepend a shift so that pixel zero lands on the footprint minimum.
    transform = Shift(extents[0][0]) & Shift(extents[1][0]) | transform
    wnew = WCS(output_frame=output_frame,
               input_frame=input_frame,
               forward_transform=transform)

    domain = []
    for lower, upper in extents:
        domain.append({
            'lower': lower,
            'upper': upper,
            'includes_lower': True,
            'includes_upper': False
        })
    wnew.domain = domain
    return wnew
def gwcs_3d():
    """Return a three-axis WCS from a pixel frame to a sky + spectral frame.

    The output is a ``CompositeFrame`` made of a helioprojective celestial
    frame named 'hpc' and a spectral frame in nanometres. The forward
    transform is whatever ``spatial_like_model()`` returns.
    """
    pixel_axes = ("x", "y", "z")
    pixel_frame = cf.CoordinateFrame(
        name="detector",
        naxes=3,
        axes_order=(0, 1, 2),
        axes_type=("pixel", "pixel", "pixel"),
        axes_names=pixel_axes,
        unit=(u.pix, u.pix, u.pix),
    )

    celestial = cf.CelestialFrame(reference_frame=Helioprojective(),
                                  name='hpc')
    spectral = cf.SpectralFrame(name="spectral", axes_order=(2,), unit=u.nm)
    world_frame = cf.CompositeFrame(frames=(celestial, spectral))

    return WCS(
        forward_transform=spatial_like_model(),
        input_frame=pixel_frame,
        output_frame=world_frame,
    )
def add_mt_frame(wcs, ra_average, dec_average, mt_ra, mt_dec):
    """
    Build a WCS whose pipeline ends in a "moving_target" frame.

    The last step of ``wcs`` is replaced by two steps: the original output
    frame paired with a shift of (``ra_average - mt_ra``,
    ``dec_average - mt_dec``), followed by a copy of that frame renamed
    'moving_target'.

    Parameters
    ----------
    wcs : `~gwcs.WCS`
        WCS object for the observation or slit.
    ra_average : float
        The average RA of all observations.
    dec_average : float
        The average DEC of all observations.
    mt_ra, mt_dec : float
        The RA, DEC of the moving target in the observation.

    Returns
    -------
    new_wcs : `~gwcs.WCS`
        The WCS for the moving target observation.

    Raises
    ------
    ValueError
        If the output frame of ``wcs`` is neither a CelestialFrame nor a
        CompositeFrame.
    """
    leading_steps = wcs._pipeline[:-1]

    target_frame = deepcopy(wcs.output_frame)
    target_frame.name = 'moving_target'

    ra_offset = ra_average - mt_ra
    dec_offset = dec_average - mt_dec

    if isinstance(target_frame, cf.CelestialFrame):
        to_target = Shift(ra_offset) & Shift(dec_offset)
    elif isinstance(target_frame, cf.CompositeFrame):
        # The third axis of the composite frame is passed through untouched.
        to_target = Shift(ra_offset) & Shift(dec_offset) & Identity(1)
    else:
        raise ValueError("Unrecognized coordinate frame.")

    leading_steps.append((wcs.output_frame, to_target))
    leading_steps.append((target_frame, None))
    return WCS(leading_steps)
def assign_moving_target_wcs(input_model):
    """Append a 'moving_target' frame to the WCS of every model.

    The moving-target RA/Dec of all exposures are averaged, and each model's
    WCS pipeline gets an extra step shifting its own moving-target position
    onto that average.

    If any exposure has no moving-target RA or Dec the average cannot be
    formed (``mean()`` of an array containing ``None`` raises ``TypeError``),
    so the step is skipped: a warning is logged and every model is marked
    'SKIPPED'.

    Parameters
    ----------
    input_model : datamodels.ModelContainer
        Container of exposures; every model is modified in place.

    Returns
    -------
    datamodels.ModelContainer
        The same container, with ``meta.cal_step.assign_mtwcs`` set to
        'COMPLETE' or 'SKIPPED' on every model.

    Raises
    ------
    ValueError
        If ``input_model`` is not a ModelContainer, or a model's output
        frame is neither a CelestialFrame nor a CompositeFrame.
    """
    # Local import keeps this block independent of file-level logger setup.
    import logging

    if not isinstance(input_model, datamodels.ModelContainer):
        raise ValueError("Expected a ModelContainer object")

    # Moving-target positions of all input exposures.
    ra_values = [model.meta.wcsinfo.mt_ra for model in input_model._models]
    dec_values = [model.meta.wcsinfo.mt_dec for model in input_model._models]

    if any(value is None for value in ra_values + dec_values):
        logger = logging.getLogger(__name__)
        logger.warning("One or more MT RA/Dec values missing in input images")
        logger.warning(
            "Step will be skipped, resulting in target misalignment")
        for model in input_model:
            model.meta.cal_step.assign_mtwcs = 'SKIPPED'
        return input_model

    mt_avra = np.array(ra_values).mean()
    mt_avdec = np.array(dec_values).mean()

    for model in input_model:
        pipeline = model.meta.wcs._pipeline[:-1]

        mt = deepcopy(model.meta.wcs.output_frame)
        mt.name = 'moving_target'

        model.meta.wcsinfo.mt_avra = mt_avra
        model.meta.wcsinfo.mt_avdec = mt_avdec

        # Offset from this exposure's target position to the average.
        rdel = mt_avra - model.meta.wcsinfo.mt_ra
        ddel = mt_avdec - model.meta.wcsinfo.mt_dec

        if isinstance(mt, cf.CelestialFrame):
            transform_to_mt = Shift(rdel) & Shift(ddel)
        elif isinstance(mt, cf.CompositeFrame):
            # The third axis of the composite frame passes through untouched.
            transform_to_mt = Shift(rdel) & Shift(ddel) & Identity(1)
        else:
            raise ValueError("Unrecognized coordinate frame.")

        pipeline.append((model.meta.wcs.output_frame, transform_to_mt))
        pipeline.append((mt, None))
        new_wcs = WCS(pipeline)

        del model.meta.wcs
        model.meta.wcs = new_wcs
        model.meta.cal_step.assign_mtwcs = 'COMPLETE'

    return input_model
def build_nirspec_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS covering footprint of the input

    The output WCS has the pipeline detector -> slit_frame -> world. The
    detector-to-slit step is built here (a linear model for ``y_slit`` and a
    tabular model for wavelength); the slit-to-world step is taken from the
    reference model's WCS.

    Parameters
    ----------
    refmodel : data model, optional
        Model whose WCS supplies the transforms and slit limits. If falsy,
        ``self.input_models[0]`` is used.

    Returns
    -------
    output_wcs : WCS
        WCS with ``bounding_box`` and ``array_shape`` set. ``self.data_size``
        is set to the same shape as a side effect.

    Raises
    ------
    ValueError
        If no output sampling wavelengths are found, or if none of the
        sampled slit-edge positions is finite.
    """
    all_wcs = [m.meta.wcs for m in self.input_models if m is not refmodel]
    if refmodel:
        all_wcs.insert(0, refmodel.meta.wcs)
    else:
        refmodel = self.input_models[0]

    refwcs = refmodel.meta.wcs

    s2d = refwcs.get_transform('slit_frame', 'detector')
    d2s = refwcs.get_transform('detector', 'slit_frame')
    s2w = refwcs.get_transform('slit_frame', 'world')

    # estimate position of the target without relying in the meta.target:
    # use the data-weighted mean slit position and wavelength.
    bbox = refwcs.bounding_box

    grid = wcstools.grid_from_bounding_box(bbox)
    _, s, lam = np.array(d2s(*grid))
    sd = s * refmodel.data
    ld = lam * refmodel.data
    good_s = np.isfinite(sd)
    if np.any(good_s):
        total = np.sum(refmodel.data[good_s])
        wmean_s = np.sum(sd[good_s]) / total
        wmean_l = np.sum(ld[good_s]) / total
    else:
        # NOTE(review): this is half the slit *length*, not the midpoint
        # 0.5 * (ymax + ymin) -- confirm which one is intended.
        wmean_s = 0.5 * (refmodel.slit_ymax - refmodel.slit_ymin)
        wmean_l = d2s(*np.mean(bbox, axis=1))[2]

    targ_ra, targ_dec, _ = s2w(0, wmean_s, wmean_l)

    ref_lam = _find_nirspec_output_sampling_wavelengths(
        all_wcs,
        targ_ra, targ_dec
    )
    ref_lam = np.array(ref_lam)
    n_lam = ref_lam.size
    if not n_lam:
        raise ValueError("Not enough data to construct output WCS.")

    x_slit = np.zeros(n_lam)
    # presumably a unit conversion of the sampling wavelengths -- TODO confirm
    lam = 1e-6 * ref_lam

    # Find the spatial pixel scale:
    y_slit_min, y_slit_max = self._max_virtual_slit_extent(
        all_wcs, targ_ra, targ_dec)

    # Probe both slit edges at ``nsampl`` wavelengths spread evenly over the
    # output wavelength grid.
    nsampl = 50
    sample_idx = (tuple((i * n_lam) // nsampl for i in range(nsampl)), )
    lam_sample = lam[sample_idx]
    xy_min = s2d(nsampl * [0], nsampl * [y_slit_min], lam_sample)
    xy_max = s2d(nsampl * [0], nsampl * [y_slit_max], lam_sample)

    good = np.logical_and(np.isfinite(xy_min), np.isfinite(xy_max))
    if not np.any(good):
        raise ValueError("Error estimating output WCS pixel scale.")

    xy1 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymin), lam)
    xy2 = s2d(x_slit, np.full(n_lam, refmodel.slit_ymax), lam)
    xylen = np.nanmax(
        np.linalg.norm(np.array(xy1) - np.array(xy2), axis=0)) + 1
    pscale = (refmodel.slit_ymax - refmodel.slit_ymin) / xylen

    # compute image span along Y-axis (length of the slit in the detector
    # plane)
    det_slit_span = np.nanmax(
        np.linalg.norm(np.subtract(xy_max, xy_min), axis=0))
    ny = int(np.ceil(det_slit_span * self.pscale_ratio + 0.5)) + 1

    border = 0.5 * (ny - det_slit_span * self.pscale_ratio) - 0.5

    # The sign of the slope follows the direction in which the detector y
    # coordinate of the second sample changes from the lower to the upper
    # slit edge.
    if xy_min[1][1] < xy_max[1][1]:
        y_slit_model = Linear1D(
            slope=pscale / self.pscale_ratio,
            intercept=y_slit_min - border * pscale * self.pscale_ratio
        )
    else:
        y_slit_model = Linear1D(
            slope=-pscale / self.pscale_ratio,
            intercept=y_slit_max + border * pscale * self.pscale_ratio
        )

    # extrapolate 1/2 pixel at the edges and make tabular model w/inverse:
    lam = lam.tolist()
    pixel_coord = list(range(n_lam))

    if len(pixel_coord) > 1:
        # left:
        slope = (lam[1] - lam[0]) / pixel_coord[1]
        lam.insert(0, -0.5 * slope + lam[0])
        pixel_coord.insert(0, -0.5)
        # right: continue the last segment half a pixel past the last
        # sample. The previous expression,
        # slope * (pixel_coord[-1] + 0.5) + lam[-2], left out the
        # "- pixel_coord[-2]" term and was only correct for two samples.
        slope = (lam[-1] - lam[-2]) / (pixel_coord[-1] - pixel_coord[-2])
        lam.append(lam[-1] + 0.5 * slope)
        pixel_coord.append(pixel_coord[-1] + 0.5)
    else:
        # Single sample: constant wavelength over the one output pixel.
        lam = 3 * lam
        pixel_coord = [-0.5, 0, 0.5]

    wavelength_transform = Tabular1D(points=pixel_coord,
                                     lookup_table=lam,
                                     bounds_error=False, fill_value=np.nan)
    wavelength_transform.inverse = Tabular1D(points=lam,
                                             lookup_table=pixel_coord,
                                             bounds_error=False,
                                             fill_value=np.nan)
    self.data_size = (ny, len(ref_lam))

    # Construct the final transform: (x, y) -> (x, y, x) feeds x to both the
    # constant x_slit pass-through and the wavelength table.
    mapping = Mapping((0, 1, 0))
    mapping.inverse = Mapping((2, 1))
    out_det2slit = mapping | Identity(1) & y_slit_model & wavelength_transform

    # Create coordinate frames
    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    slit_spatial = cf.Frame2D(name='slit_spatial', axes_order=(0, 1),
                              unit=("", ""), axes_names=('x_slit', 'y_slit'))
    spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                            unit=(u.micron,), axes_names=('wavelength',))
    slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame')
    sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                            reference_frame=coord.ICRS())
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, out_det2slit), (slit_frame, s2w), (world, None)]
    output_wcs = WCS(pipeline)

    # Compute bounding box and output array shape.  Add one to the y (slit)
    # height to account for the half pixel at top and bottom due to pixel
    # coordinates being centers of pixels
    bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
    output_wcs.bounding_box = bounding_box
    output_wcs.array_shape = self.data_size

    return output_wcs
def build_nirspec_lamp_output_wcs(self):
    """
    Create a spatial/spectral WCS output frame for NIRSpec lamp mode

    Creates output frame by linearly fitting x_msa, y_msa along the slit and
    producing a lookup table to interpolate wavelengths in the dispersion
    direction. Only ``self.input_models[0]`` is used.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS, with
        ``array_shape``, ``pixel_shape`` and ``bounding_box`` set.
    """
    model = self.input_models[0]
    wcs = model.meta.wcs
    bbox = wcs.bounding_box
    grid = wcstools.grid_from_bounding_box(bbox)
    x_msa, y_msa, lam = np.array(wcs(*grid))

    # Handle vertical (MIRI) or horizontal (NIRSpec) dispersion.  The
    # following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
    spectral_axis = find_dispersion_axis(model)
    spatial_axis = spectral_axis ^ 1

    # Compute the wavelength array, trimming NaNs from the ends
    # In many cases, a whole slice is NaNs, so ignore those warnings.
    # catch_warnings() restores the caller's filters on exit; the previous
    # simplefilter()/resetwarnings() pair discarded every filter in the
    # process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        wavelength_array = np.nanmedian(lam, axis=spectral_axis)
    wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

    # Find the center ra and dec for this slit at central wavelength
    lam_center_index = int((bbox[spectral_axis][1] -
                            bbox[spectral_axis][0]) / 2)
    x_msa_array = x_msa.T[lam_center_index]
    y_msa_array = y_msa.T[lam_center_index]
    x_msa_array = x_msa_array[~np.isnan(x_msa_array)]
    y_msa_array = y_msa_array[~np.isnan(y_msa_array)]

    # Estimate and fit the spatial sampling: output pixel coordinates are
    # the input sample indices divided by pscale_ratio.
    fitter = LinearLSQFitter()
    fit_model = Linear1D()
    xstop = x_msa_array.shape[0] / self.pscale_ratio
    xstep = 1 / self.pscale_ratio
    ystop = y_msa_array.shape[0] / self.pscale_ratio
    ystep = 1 / self.pscale_ratio
    pix_to_x_msa = fitter(fit_model, np.arange(0, xstop, xstep), x_msa_array)
    pix_to_y_msa = fitter(fit_model, np.arange(0, ystop, ystep), y_msa_array)

    step = 1 / self.pscale_ratio
    stop = wavelength_array.shape[0] / self.pscale_ratio
    points = np.arange(0, stop, step)
    pix_to_wavelength = Tabular1D(points=points,
                                  lookup_table=wavelength_array,
                                  bounds_error=False, fill_value=None,
                                  name='pix2wavelength')

    # Tabular models need an inverse explicitly defined.
    # If the wavelength array is descending instead of ascending, both
    # points and lookup_table need to be reversed in the inverse transform
    # for scipy.interpolate to work properly
    points = wavelength_array
    lookup_table = np.arange(0, stop, step)

    if not np.all(np.diff(wavelength_array) > 0):
        points = points[::-1]
        lookup_table = lookup_table[::-1]
    pix_to_wavelength.inverse = Tabular1D(points=points,
                                          lookup_table=lookup_table,
                                          bounds_error=False, fill_value=None,
                                          name='wavelength2pix')

    # For the input mapping, duplicate the spatial coordinate
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))
    mapping.inverse = Mapping((2, 1))

    # The final transform
    # define the output wcs
    transform = mapping | pix_to_x_msa & pix_to_y_msa & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    sky = cf.Frame2D(name=f'resampled_{model.meta.wcs.output_frame.name}',
                     axes_order=(0, 1))
    spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                            unit=(u.micron,), axes_names=('wavelength',))
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, transform),
                (world, None)]

    output_wcs = WCS(pipeline)

    # Compute the output array size and bounding box
    output_array_size = [0, 0]
    output_array_size[spectral_axis] = int(
        np.ceil(len(wavelength_array) / self.pscale_ratio))
    x_size = len(x_msa_array)
    output_array_size[spatial_axis] = int(
        np.ceil(x_size / self.pscale_ratio))
    # turn the size into a numpy shape in (y, x) order
    output_wcs.array_shape = output_array_size[::-1]
    output_wcs.pixel_shape = output_array_size
    bounding_box = resample_utils.wcs_bbox_from_shape(output_array_size[::-1])
    output_wcs.bounding_box = bounding_box

    return output_wcs
def build_interpolated_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS output frame

    RA and Dec are fitted linearly along the slit, and wavelengths are
    interpolated from a lookup table along the dispersion direction.

    Parameters
    ----------
    refmodel : `~jwst.datamodels.DataModel`
        The reference input image from which the fiducial WCS is created.
        If not specified, the first image in self.input_models is used.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS. ``self.data_size``
        is set to the output shape in (y, x) order as a side effect.
    """
    if refmodel is None:
        refmodel = self.input_models[0]
    refwcs = refmodel.meta.wcs
    bb = refwcs.bounding_box

    pixel_grid = wcstools.grid_from_bounding_box(bb)
    ra, dec, lam = np.array(refwcs(*pixel_grid))

    # Both axis indices are 0 or 1 and are complements of each other.
    spectral_axis = find_dispersion_axis(lam)
    spatial_axis = spectral_axis ^ 1

    # Median wavelength along the spectral axis, with NaN entries dropped.
    wavelengths = np.nanmedian(lam, axis=spectral_axis)
    wavelengths = wavelengths[~np.isnan(wavelengths)]
    n_wave = wavelengths.shape[0]

    # RA and Dec up the slit (spatial direction) at the centre of the
    # dispersion; which way to slice depends on the spectral axis.
    spectral_lo, spectral_hi = bb[spectral_axis][0], bb[spectral_axis][1]
    centre = int((spectral_hi - spectral_lo) / 2)
    if spectral_axis:
        ra_slit = ra[centre]
        dec_slit = dec[centre]
    else:
        ra_slit = ra.T[centre]
        dec_slit = dec.T[centre]
    ra_slit = ra_slit[~np.isnan(ra_slit)]
    dec_slit = dec_slit[~np.isnan(dec_slit)]

    fitter = LinearLSQFitter()
    line = Linear1D()
    pix_to_ra = fitter(line, np.arange(ra_slit.shape[0]), ra_slit)
    pix_to_dec = fitter(line, np.arange(dec_slit.shape[0]), dec_slit)

    # Tabular interpolation model, pixels -> lambda
    pix_to_wavelength = Tabular1D(lookup_table=wavelengths,
                                  bounds_error=False, fill_value=None,
                                  name='pix2wavelength')

    # Tabular models need an explicit inverse. When the wavelengths are not
    # strictly ascending, both arrays are reversed for the inverse so that
    # scipy.interpolate works properly.
    inverse_points = wavelengths
    inverse_table = np.arange(n_wave)
    if not np.all(np.diff(wavelengths) > 0):
        inverse_points = inverse_points[::-1]
        inverse_table = inverse_table[::-1]
    pix_to_wavelength.inverse = Tabular1D(points=inverse_points,
                                          lookup_table=inverse_table,
                                          bounds_error=False,
                                          fill_value=None,
                                          name='wavelength2pix')

    # For the input mapping, duplicate the spatial coordinate
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

    # Sometimes the slit is perpendicular to the RA or Dec axis, which makes
    # the corresponding fitted slope nearly zero. The automatic
    # mapping.inverse uses the 2nd spatial coordinate, i.e. Dec; when the
    # Dec slope is nearly zero an explicit inverse is installed instead,
    # ordered for vertical or horizontal dispersion on the detector.
    if np.isclose(pix_to_dec.slope, 0, atol=1e-8):
        inverse_order = (0, 1)
        mapping.inverse = Mapping(
            inverse_order[::-1] if spatial_axis else inverse_order)

    # The final transform
    transform = mapping | pix_to_ra & pix_to_dec & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                            reference_frame=coord.ICRS())
    spec = cf.SpectralFrame(name='spectral', axes_order=(2,),
                            unit=(u.micron,), axes_names=('wavelength',))
    world = cf.CompositeFrame([sky, spec], name='world')

    output_wcs = WCS([(det, transform), (world, None)])

    # Output array size in WCS axes order, i.e. (x, y) ...
    size_xy = [0, 0]
    size_xy[spectral_axis] = len(wavelengths)
    size_xy[spatial_axis] = len(ra_slit)

    # ... turned into a numpy shape in (y, x) order.
    self.data_size = tuple(size_xy[::-1])
    output_wcs.bounding_box = resample_utils.wcs_bbox_from_shape(
        self.data_size)

    return output_wcs
def build_interpolated_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS output frame

    The frame is produced by a linear fit of RA and Dec along the slit plus
    a wavelength lookup table along the dispersion direction.

    Parameters
    ----------
    refmodel : `~jwst.datamodels.DataModel`
        The reference input image from which the fiducial WCS is created.
        If not specified, the first image in self.input_models is used.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS. As a side effect
        ``self.data_size`` receives the output shape in (y, x) order.
    """
    reference = self.input_models[0] if refmodel is None else refmodel
    refwcs = reference.meta.wcs
    bb = refwcs.bounding_box

    ra, dec, lam = np.array(refwcs(*wcstools.grid_from_bounding_box(bb)))

    spectral_axis = find_dispersion_axis(lam)
    spatial_axis = spectral_axis ^ 1  # the other one of 0 / 1

    # Wavelength per spectral pixel: NaN-aware median, NaN results removed.
    median_lam = np.nanmedian(lam, axis=spectral_axis)
    wave_table = median_lam[~np.isnan(median_lam)]

    # Sample RA/Dec up the slit at the middle of the dispersion range. The
    # arrays are transposed first when the spectral axis is 0.
    mid = int((bb[spectral_axis][1] - bb[spectral_axis][0]) / 2)
    if not spectral_axis:
        ra_cut, dec_cut = ra.T[mid], dec.T[mid]
    else:
        ra_cut, dec_cut = ra[mid], dec[mid]
    ra_cut = ra_cut[~np.isnan(ra_cut)]
    dec_cut = dec_cut[~np.isnan(dec_cut)]

    fitter = LinearLSQFitter()
    fit_model = Linear1D()
    ra_pixels = np.arange(ra_cut.shape[0])
    dec_pixels = np.arange(dec_cut.shape[0])
    pix_to_ra = fitter(fit_model, ra_pixels, ra_cut)
    pix_to_dec = fitter(fit_model, dec_pixels, dec_cut)

    # Forward lookup: pixel index -> wavelength.
    pix_to_wavelength = Tabular1D(lookup_table=wave_table,
                                  bounds_error=False, fill_value=None,
                                  name='pix2wavelength')

    # Inverse lookup: wavelength -> pixel index. It has to be defined by
    # hand, and scipy.interpolate needs the arrays reversed whenever the
    # wavelengths are not strictly increasing.
    pixel_index = np.arange(wave_table.shape[0])
    ascending = np.all(np.diff(wave_table) > 0)
    if ascending:
        points, lookup_table = wave_table, pixel_index
    else:
        points, lookup_table = wave_table[::-1], pixel_index[::-1]
    pix_to_wavelength.inverse = Tabular1D(points=points,
                                          lookup_table=lookup_table,
                                          bounds_error=False,
                                          fill_value=None,
                                          name='wavelength2pix')

    # The spatial coordinate is duplicated so it can drive both fits.
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

    # A slit perpendicular to RA or Dec gives a near-zero slope for that
    # fit. The default mapping.inverse relies on the 2nd spatial
    # coordinate (Dec), so it is overridden when the Dec slope is
    # near zero; the order depends on the dispersion direction.
    if np.isclose(pix_to_dec.slope, 0, atol=1e-8):
        if spatial_axis:
            mapping.inverse = Mapping((0, 1)[::-1])
        else:
            mapping.inverse = Mapping((0, 1))

    transform = mapping | pix_to_ra & pix_to_dec & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    sky = cf.CelestialFrame(name='sky',
                            axes_order=(0, 1),
                            reference_frame=coord.ICRS())
    spec = cf.SpectralFrame(name='spectral',
                            axes_order=(2, ),
                            unit=(u.micron, ),
                            axes_names=('wavelength', ))
    world = cf.CompositeFrame([sky, spec], name='world')

    steps = [(det, transform), (world, None)]
    output_wcs = WCS(steps)

    # Sizes are assembled in WCS (x, y) order and then flipped to the numpy
    # (y, x) convention.
    output_array_size = [0, 0]
    output_array_size[spectral_axis] = len(wave_table)
    output_array_size[spatial_axis] = len(ra_cut)
    self.data_size = tuple(output_array_size[::-1])

    bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
    output_wcs.bounding_box = bounding_box

    return output_wcs
def build_nirspec_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS covering footprint of the input

    The pipeline of the result is detector -> slit_frame -> world. The
    detector-to-slit step is constructed here; the slit-to-world step is
    taken from the reference WCS.

    Parameters
    ----------
    refmodel : data model, optional
        Model providing the reference WCS. If falsy,
        ``self.input_models[0]`` is used.

    Returns
    -------
    output_wcs : WCS
        WCS with its ``bounding_box`` set. ``self.data_size`` is set to
        (slit height in pixels, number of wavelength samples) as a side
        effect.
    """
    if not refmodel:
        refmodel = self.input_models[0]
    refwcs = refmodel.meta.wcs
    bb = refwcs.bounding_box

    ref_det2slit = refwcs.get_transform('detector', 'slit_frame')
    ref_slit2world = refwcs.get_transform('slit_frame', 'world')

    # Slit-frame coordinates of every pixel in the bounding box, addressable
    # by the slit frame's axis names.
    det_grid = wcstools.grid_from_bounding_box(bb, step=(1, 1))
    SlitGrid = namedtuple('Grid', refwcs.slit_frame.axes_names)
    slit_grid = SlitGrid(*ref_det2slit(*det_grid))

    # Compute spatial transform from detector to slit
    mean_wavelength = np.nanmean(slit_grid.wavelength)

    # Number of pixels sampled by a single shutter: detector distance
    # between y_slit = -0.5 and y_slit = +0.5 at the mean wavelength.
    shutter_edges = np.array([[0., 0.],
                              [-.5, .5],
                              np.repeat(mean_wavelength, 2)])
    edge_pixels = np.array(ref_det2slit.inverse(*shutter_edges)).T
    pix_per_shutter = np.linalg.norm(edge_pixels[0] - edge_pixels[1])

    # Slit extent converted to pixel units
    ymin = np.nanmin(slit_grid.y_slit) * pix_per_shutter
    ymax = np.nanmax(slit_grid.y_slit) * pix_per_shutter
    slit_height_pix = int(abs(ymax - ymin + 0.5))

    # Wavelength per column (NaN-aware mean over axis 0) as a tabular model,
    # with an explicit inverse.
    wave_table = np.nanmean(slit_grid.wavelength, axis=0)
    column_index = np.arange(slit_grid.wavelength.shape[1])
    wavelength_transform = Tabular1D(lookup_table=wave_table,
                                     bounds_error=False,
                                     fill_value=np.nan)
    wavelength_transform.inverse = Tabular1D(points=wave_table,
                                             lookup_table=column_index,
                                             bounds_error=False,
                                             fill_value=np.nan)

    # Detector to slit: y is flipped and rescaled, x_slit is constant.
    flip_and_scale = Scale(-1 / pix_per_shutter)
    recentre = Shift(ymax / pix_per_shutter)
    yslit_transform = flip_and_scale | recentre
    xslit_transform = Const1D(-0.)
    xslit_transform.inverse = Const1D(0.)

    # (x, y) -> (x, y, x): x drives both x_slit and the wavelength table.
    coord_mapping = Mapping((0, 1, 0))
    coord_mapping.inverse = Mapping((2, 1))
    out_det2slit = coord_mapping | (
        xslit_transform & yslit_transform & wavelength_transform)

    # Coordinate frames
    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    slit_spatial = cf.Frame2D(name='slit_spatial',
                              axes_order=(0, 1),
                              unit=("", ""),
                              axes_names=('x_slit', 'y_slit'))
    spec = cf.SpectralFrame(name='spectral',
                            axes_order=(2, ),
                            unit=(u.micron, ),
                            axes_names=('wavelength', ))
    slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame')
    sky = cf.CelestialFrame(name='sky',
                            axes_order=(0, 1),
                            reference_frame=coord.ICRS())
    world = cf.CompositeFrame([sky, spec], name='world')

    output_wcs = WCS([
        (det, out_det2slit),
        (slit_frame, ref_slit2world),
        (world, None),
    ])

    self.data_size = (slit_height_pix, len(wave_table))
    output_wcs.bounding_box = resample_utils.bounding_box_from_shape(
        self.data_size)

    return output_wcs
def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=None):
    """
    Create a WCS from a list of input data models.

    A fiducial point in the output coordinate frame is created from the
    footprints of all WCS objects. For a spatial frame this is the center
    of the union of the footprints. For a spectral frame the fiducial is in
    the beginning of the footprint range.
    If ``refmodel`` is None, the first WCS object in the list is considered
    a reference. The output coordinate frame and projection (for celestial
    frames) is taken from ``refmodel``.
    If ``transform`` is not supplied, a compound transform is created from
    the rotation matrix of ``refmodel`` and, when it has sky axes, a pixel
    scale computed from its WCS.

    Parameters
    ----------
    dmodels : list of `~jwst.datamodels.DataModel`
        A list of data models.
    refmodel : `~jwst.datamodels.DataModel`, optional
        This model's WCS is used as a reference.
        The output coordinate frame, the projection and a
        scaling and rotation transform is created from it. If not supplied
        the first model in the list is used as ``refmodel``.
    transform : `~astropy.modeling.core.Model`, optional
        A transform, passed to :meth:`~gwcs.wcstools.wcs_from_fiducial`
        If not supplied Rotation | Scaling is computed from ``refmodel``.
    bounding_box : tuple, optional
        Passed to ``compute_fiducial`` together with the input WCS list.
        The ``bounding_box`` assigned to the returned WCS is always
        recomputed from the footprints of the inputs.

    Returns
    -------
    wnew : `~gwcs.WCS`
        The new WCS, with its ``bounding_box`` set.

    Raises
    ------
    ValueError
        If the collected WCS list is not iterable.
    TypeError
        If an input WCS is not a ``gwcs.WCS`` or ``refmodel`` is not a
        ``DataModel``.
    """
    bb = bounding_box
    wcslist = [im.meta.wcs for im in dmodels]

    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of WCS objects.")

    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError("All items in wcslist are to be instances of gwcs.WCS.")

    if refmodel is None:
        refmodel = dmodels[0]
    elif not isinstance(refmodel, DataModel):
        raise TypeError("Expected refmodel to be an instance of DataModel.")

    fiducial = compute_fiducial(wcslist, bb)
    ref_fiducial = compute_fiducial([refmodel.meta.wcs])

    prj = astmodels.Pix2Sky_TAN()

    # The sky axes are needed further down regardless of where the transform
    # comes from. They used to be looked up only inside the
    # ``transform is None`` branch, so supplying a transform raised a
    # NameError.
    wcsinfo = pointing.wcsinfo_from_model(refmodel)
    sky_axes, _, _ = gwutils.get_axes(wcsinfo)

    if transform is None:
        steps = []

        # Need to put the rotation matrix (List[float, float, float, float])
        # returned from calc_rotation_matrix into the correct shape for
        # constructing the transformation
        pc = np.reshape(
            calc_rotation_matrix(
                np.deg2rad(refmodel.meta.wcsinfo.roll_ref),
                np.deg2rad(refmodel.meta.wcsinfo.v3yangle),
                vparity=refmodel.meta.wcsinfo.vparity),
            (2, 2)
        )
        steps.append(astmodels.AffineTransformation2D(pc))

        if sky_axes:
            scale = compute_scale(refmodel.meta.wcs, ref_fiducial)
            steps.append(astmodels.Scale(scale) & astmodels.Scale(scale))

        # Chain the collected steps into a single compound model.
        transform = functools.reduce(lambda x, y: x | y, steps)

    out_frame = refmodel.meta.wcs.output_frame
    input_frame = dmodels[0].meta.wcs.input_frame
    wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame,
                             projection=prj, transform=transform)

    # temporary fix before gwcs 0.14 is released: wcs_from_fiducial is not
    # given the input frame, so it is swapped into the first pipeline step.
    pipe = wnew.pipeline[:]
    pipe[0] = (input_frame, pipe[0][1])
    wnew = WCS(pipe)

    # Footprints of all inputs expressed in the pixel space of the new WCS.
    footprints = [w.footprint().T for w in wcslist]
    domain_bounds = np.hstack(
        [wnew.backward_transform(*f) for f in footprints])

    # Shift every axis in place so that its minimum sits at -0.5.
    for axs in domain_bounds:
        axs -= (axs.min() + .5)

    bounding_box = []
    for axis in out_frame.axes_order:
        axis_min, axis_max = (domain_bounds[axis].min(),
                              domain_bounds[axis].max())
        bounding_box.append((axis_min, axis_max))
    bounding_box = tuple(bounding_box)

    # Offset the pixel origin by half the extent of each sky axis.
    ax1, ax2 = np.array(bounding_box)[sky_axes]
    offset1 = (ax1[1] - ax1[0]) / 2
    offset2 = (ax2[1] - ax2[0]) / 2
    offsets = astmodels.Shift(-offset1) & astmodels.Shift(-offset2)

    wnew.insert_transform('detector', offsets, after=True)
    wnew.bounding_box = bounding_box

    return wnew
def build_nirspec_output_wcs(self, refwcs=None):
    """
    Create a simple output wcs covering footprint of the input datamodels

    Only the first model in ``self.input_models`` is used.  The slit
    position, angular size and rotation are measured from the reference WCS
    evaluated over its bounding box, and a new WCS built from
    shift/rotate/scale/TAN/sky-rotation (spatial) and scale/shift (spectral)
    transforms is created.

    Parameters
    ----------
    refwcs : WCS object, optional
        Reference WCS.  When ``None`` the ``meta.wcs`` of the first input
        model is used.

    Returns
    -------
    None
        Results are stored on the instance as ``output_spatial_scale``,
        ``output_spectral_scale`` and ``output_wcs``.
    """
    # TODO: generalize this for more than one input datamodel
    # TODO: generalize this for imaging modes with distorted wcs
    input_model = self.input_models[0]
    # Identity test: ``== None`` would dispatch to the object's own __eq__.
    if refwcs is None:
        refwcs = input_model.meta.wcs

    # Generate grid of sky coordinates for area within bounding box
    bb = refwcs.bounding_box
    det = x, y = wcstools.grid_from_bounding_box(bb, step=(1, 1), center=True)
    ra, dec, lam = refwcs(*det)
    x_center = int((bb[0][1] - bb[0][0]) / 2)
    y_center = int((bb[1][1] - bb[1][0]) / 2)
    log.debug("Center of bounding box: {} {}".format(x_center, y_center))

    # Compute slit angular size, slit center sky coords.
    # For every row, find the x position at which that row has the same
    # wavelength as the central pixel, interpolating over a small window of
    # columns around the central column.
    sz = 3
    window = slice(x_center - sz + 1, x_center + sz)
    xpos = []
    for row in lam:
        if np.isnan(row[x_center]):
            xpos.append(np.nan)
        else:
            f = interpolate.interp1d(row[window], x[y_center, window],
                                     bounds_error=False,
                                     fill_value='extrapolate')
            xpos.append(f(lam[y_center, x_center]))

    # Rows whose central column has a valid wavelength
    valid_rows = ~np.isnan(lam[:, x_center])
    x_arg = np.array(xpos)[valid_rows]
    y_arg = y[valid_rows, x_center]
    slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
    slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
    # Pixel numbers run in the reverse order of the slit samples
    pix_num = np.flipud(np.arange(len(slit_ra)))
    interpol_ra = interpolate.interp1d(pix_num, slit_ra)
    interpol_dec = interpolate.interp1d(pix_num, slit_dec)
    slit_center_pix = len(slit_spec_ref) / 2. - 1
    log.debug('Slit center pix: {0}'.format(slit_center_pix))
    slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                               dec=interpol_dec(slit_center_pix),
                               unit=u.deg)
    log.debug('Slit center: {0}'.format(slit_center_sky))
    log.debug('Fiducial: {0}'.format(
        resample_utils.compute_spec_fiducial([refwcs])))
    angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
    log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
    dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
    offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
    log.debug('Offset up the slit: {0}'.format(offset_up_slit))

    # Compute spatial and spectral scales
    xposn = np.array(xpos)[~np.isnan(xpos)]
    dx = xposn[-1] - xposn[0]
    slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
    spatial_scale = angular_slit_size / slit_npix
    log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
    spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

    # Compute slit angle relative (clockwise) to y axis
    slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
    slit_rot_angle = slit_rot_angle.value
    log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

    # Compute transform for output frame
    roll_ref = input_model.meta.wcsinfo.roll_ref
    min_lam = np.nanmin(lam)
    offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)
    # TODO: double-check the signs on the following rotation angles
    rot = Rotation2D(roll_ref + slit_rot_angle)
    scale = Scale(spatial_scale.value) & Scale(spatial_scale.value)
    tan = Pix2Sky_TAN()
    lon_pole = _compute_lon_pole(slit_center_sky, tan)
    skyrot = RotateNative2Celestial(slit_center_sky.ra.value,
                                    slit_center_sky.dec.value,
                                    lon_pole.value)
    spatial_trans = offset | rot | scale | tan | skyrot
    spectral_trans = Scale(spectral_scale) | Shift(min_lam)
    # The second input feeds both spatial axes; the first feeds the
    # spectral axis.
    mapping = Mapping((1, 1, 0))
    mapping.inverse = Mapping((2, 1))
    transform = mapping | spatial_trans & spectral_trans
    transform.outputs = ('ra', 'dec', 'lamda')

    # Build the output wcs
    input_frame = refwcs.input_frame
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the bounding_box in the output frame wcs object by running the
    # input sky grid back through the new transform
    bounding_box_grid = wnew.backward_transform(ra, dec, lam)
    wnew.bounding_box = tuple(
        (np.nanmin(bounding_box_grid[axis]), np.nanmax(bounding_box_grid[axis]))
        for axis in input_frame.axes_order
    )

    # Update class properties
    self.output_spatial_scale = spatial_scale
    self.output_spectral_scale = spectral_scale
    self.output_wcs = wnew
def test_round_trip_gwcs(tmpdir):
    """
    Attach a two-step gWCS to an AstroData extension, write it to a FITS
    file, read it back and check that the reloaded WCS matches the original.
    """
    from gwcs import coordinate_frames as cf
    from gwcs import WCS

    pixels = np.zeros((10, 10), dtype=np.float32)
    ad_written = astrodata.create(fits.PrimaryHDU(),
                                  [fits.ImageHDU(pixels, name='SCI')])

    # --- Step 1: detector pixels -> pixels in a reference row -------------
    # (removes relative distortions in wavelength)
    pix_units = (u.pix, u.pix)
    det_frame = cf.Frame2D(name='det_mosaic', axes_names=('x', 'y'),
                           unit=pix_units)
    dref_frame = cf.Frame2D(name='dist_ref_row', axes_names=('xref', 'y'),
                            unit=pix_units)

    # Invented forward model, loosely resembling real distortions
    forward_distortion = models.Chebyshev2D(
        2, 2,
        c0_0=4.81125, c1_0=5.43375,
        c0_1=-0.135, c1_1=-0.405,
        c0_2=0.30375, c1_2=0.91125,
        x_domain=[0., 9.], y_domain=[0., 9.],
    )

    # Only an approximate inverse, which is good enough for this test
    inverse_distortion = models.Chebyshev2D(
        2, 2,
        c0_0=4.89062675, c1_0=5.68581232, c2_0=-0.00590263,
        c0_1=0.11755526, c1_1=0.35652358, c2_1=-0.01193828,
        c0_2=-0.29996306, c1_2=-0.91823397, c2_2=0.02390594,
        x_domain=[-1.5, 12.], y_domain=[0., 9.],
    )

    # 2D mapping detector -> reference row pixels, with explicit inverse
    distrans = (models.Mapping((0, 1, 1))
                | (forward_distortion & models.Identity(1)))
    distrans.inverse = (models.Mapping((0, 1, 1))
                        | (inverse_distortion & models.Identity(1)))

    # --- Step 2: reference row pixels -> linear, row-stacked spectra ------
    spec_frame = cf.SpectralFrame(axes_order=(0, ), unit=u.nm,
                                  axes_names='lambda', name='wavelength')
    row_frame = cf.CoordinateFrame(1, 'SPATIAL', axes_order=(1, ),
                                   unit=u.pix, axes_names='y', name='row')
    rss_frame = cf.CompositeFrame([spec_frame, row_frame])

    # Toy wavelength solution and its approximate inverse
    forward_wavecal = models.Chebyshev1D(2, c0=500.075, c1=0.05, c2=0.001,
                                         domain=[0, 9])
    inverse_wavecal = models.Chebyshev1D(2, c0=4.59006292, c1=4.49601817,
                                         c2=-0.08989608,
                                         domain=[500.026, 500.126])

    # 2D mapping reference pixels -> wavelength, with explicit inverse
    wavtrans = forward_wavecal & models.Identity(1)
    wavtrans.inverse = inverse_wavecal & models.Identity(1)

    # Full WCS chain made of the two steps above
    wcs_steps = [(det_frame, distrans),
                 (dref_frame, wavtrans),
                 (rss_frame, None)]
    ad_written[0].nddata.wcs = WCS(wcs_steps)

    # Write the AstroData instance out and read it back in
    testfile = str(tmpdir.join('round_trip_gwcs.fits'))
    ad_written.write(testfile)
    ad_reloaded = astrodata.open(testfile)

    wcs_before = ad_written[0].nddata.wcs
    wcs_after = ad_reloaded[0].nddata.wcs

    # # Temporary workaround for issue #9809, to ensure the test is correct:
    # wcs_after.forward_transform[1].x_domain = (0, 9)
    # wcs_after.forward_transform[1].y_domain = (0, 9)
    # wcs_after.forward_transform[3].domain = (0, 9)
    # wcs_after.backward_transform[0].domain = (500.026, 500.126)
    # wcs_after.backward_transform[3].x_domain = (-1.5, 12.)
    # wcs_after.backward_transform[3].y_domain = (0, 9)

    # The reloaded attribute must be a real gWCS object
    assert isinstance(wcs_after, WCS)

    # Compare the transforms in both directions.  The inverse is
    # effectively checked twice (backward_transform and
    # forward_transform.inverse); the transforms are regenerated as new
    # compound models on each access, so avoiding that would be convoluted.
    compare_models(wcs_before.forward_transform, wcs_after.forward_transform)
    compare_models(wcs_before.backward_transform,
                   wcs_after.backward_transform)

    # Co-ordinate frames must match as well
    for frame_name in wcs_before.available_frames:
        original_repr = repr(getattr(wcs_before, frame_name))
        reloaded_repr = repr(getattr(wcs_after, frame_name))
        assert original_repr == reloaded_repr

    # Finally compare a handful of transformed values in each direction
    y, x = np.mgrid[0:9:2, 0:9:2]
    np.testing.assert_allclose(wcs_before(x, y), wcs_after(x, y),
                               rtol=1e-7, atol=0.)

    y, w = np.mgrid[0:9:2, 500.025:500.12:0.0225]
    np.testing.assert_allclose(wcs_before.invert(w, y),
                               wcs_after.invert(w, y),
                               rtol=1e-7, atol=0.)
def build_interpolated_output_wcs(self, refmodel=None):
    """
    Create a spatial/spectral WCS output frame using all the input models

    Creates output frame by linearly fitting RA, Dec along the slit and
    producing a lookup table to interpolate wavelengths in the dispersion
    direction.

    Parameters
    ----------
    refmodel : `~jwst.datamodels.DataModel`
        The reference input image from which the fiducial WCS is created.
        NOTE(review): this argument is not read anywhere in the body; the
        first model in ``self.input_models`` always sets the spatial
        sampling.

    Returns
    -------
    output_wcs : `~gwcs.WCS` object
        A gwcs WCS object defining the output frame WCS

    Notes
    -----
    As a side effect ``self.data_size`` is set to the output array shape in
    (y, x) order.
    """
    # for each input model convert slit x,y to ra,dec,lam
    # use first input model to set spatial scale
    # use center of appended ra and dec arrays to set up
    # center of final ra,dec
    # append all ra,dec, wavelength array for each slit
    # use first model to initialize wavelength array
    # append wavelengths that fall outside the endpoint of
    # of wavelength array when looping over additional data

    all_wavelength = []
    all_ra_slit = []
    all_dec_slit = []

    for im, model in enumerate(self.input_models):
        wcs = model.meta.wcs
        bb = wcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bb)
        ra, dec, lam = np.array(wcs(*grid))
        spectral_axis = find_dispersion_axis(model)
        spatial_axis = spectral_axis ^ 1

        # Compute the wavelength array, trimming NaNs from the ends
        # In many cases, a whole slice is NaNs, so ignore those warnings.
        # catch_warnings() restores the caller's filters afterwards instead
        # of wiping every filter with resetwarnings().
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            wavelength_array = np.nanmedian(lam, axis=spectral_axis)
        wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

        # We need to estimate the spatial sampling to use for the output WCS.
        # It is assumed the spatial sampling is the same for all the input
        # models. So we can use the first input model to set the spatial
        # sampling.
        # Steps to do this for first input model:
        # 1. find the middle of the spectrum in wavelength
        # 2. Pull out the ra and dec at the center of the slit.
        # 3. Find the mean ra,dec and the center of the slit this will
        #    represent the tangent point
        # 4. Convert ra,dec -> tangent plane projection: x_tan,y_tan
        # 5. using x_tan, y_tan perform a linear fit to find spatial sampling
        # The first input model initializes the wavelength array and defines
        # the spatial scale of the output wcs
        if im == 0:
            all_wavelength.extend(wavelength_array)

            lam_center_index = int(
                (bb[spectral_axis][1] - bb[spectral_axis][0]) / 2)
            if spatial_axis == 0:
                ra_center = ra[lam_center_index, :]
                dec_center = dec[lam_center_index, :]
            else:
                ra_center = ra[:, lam_center_index]
                dec_center = dec[:, lam_center_index]
            # find the ra and dec for this slit using center of slit
            ra_center_pt = np.nanmean(ra_center)
            dec_center_pt = np.nanmean(dec_center)

            if resample_utils.is_sky_like(model.meta.wcs.output_frame):
                # convert ra and dec to tangent projection
                tan = Pix2Sky_TAN()
                native2celestial = RotateNative2Celestial(
                    ra_center_pt, dec_center_pt, 180)
                undist2sky1 = tan | native2celestial
                # Filter out RuntimeWarnings due to computed NaNs in the WCS
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    # at this center of slit find x,y tangent projection
                    x_tan, y_tan = undist2sky1.inverse(ra, dec)
            else:
                # for non sky-like output frames, no need to do tangent
                # plane projections but we still use the same variables
                x_tan, y_tan = ra, dec

            # pull out data from center
            if spectral_axis == 0:
                x_tan_array = x_tan.T[lam_center_index]
                y_tan_array = y_tan.T[lam_center_index]
            else:
                x_tan_array = x_tan[lam_center_index]
                y_tan_array = y_tan[lam_center_index]

            x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
            y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

            # estimate the spatial sampling
            fitter = LinearLSQFitter()
            fit_model = Linear1D()

            xstop = x_tan_array.shape[0] / self.pscale_ratio
            xstep = 1 / self.pscale_ratio
            ystop = y_tan_array.shape[0] / self.pscale_ratio
            ystep = 1 / self.pscale_ratio
            pix_to_xtan = fitter(fit_model, np.arange(0, xstop, xstep),
                                 x_tan_array)
            pix_to_ytan = fitter(fit_model, np.arange(0, ystop, ystep),
                                 y_tan_array)

        # append all ra and dec values to use later to find min and max
        # ra and dec
        ra_use = ra.flatten()
        ra_use = ra_use[~np.isnan(ra_use)]
        dec_use = dec.flatten()
        dec_use = dec_use[~np.isnan(dec_use)]
        all_ra_slit.append(ra_use)
        all_dec_slit.append(dec_use)

        # now check wavelength array to see if we need to add to it
        this_minw = np.min(wavelength_array)
        this_maxw = np.max(wavelength_array)
        all_minw = np.min(all_wavelength)
        all_maxw = np.max(all_wavelength)

        if this_minw < all_minw:
            all_wavelength.extend(
                wavelength_array[wavelength_array < all_minw])
        if this_maxw > all_maxw:
            all_wavelength.extend(
                wavelength_array[wavelength_array > all_maxw])

    # done looping over set of models
    all_ra = np.hstack(all_ra_slit)
    all_dec = np.hstack(all_dec_slit)
    all_wave = np.hstack(all_wavelength)
    all_wave = all_wave[~np.isnan(all_wave)]
    all_wave = np.sort(all_wave, axis=None)
    # Tabular interpolation model, pixels -> lambda
    wavelength_array = np.unique(all_wave)
    # Check if the data is MIRI LRS FIXED Slit. If it is then
    # the wavelength array needs to be flipped so that the resampled
    # dispersion direction matches the dispersion direction on the detector.
    if self.input_models[0].meta.exposure.type == 'MIR_LRS-FIXEDSLIT':
        wavelength_array = np.flip(wavelength_array, axis=None)

    step = 1 / self.pscale_ratio
    stop = wavelength_array.shape[0] / self.pscale_ratio
    points = np.arange(0, stop, step)
    pix_to_wavelength = Tabular1D(points=points,
                                  lookup_table=wavelength_array,
                                  bounds_error=False, fill_value=None,
                                  name='pix2wavelength')

    # Tabular models need an inverse explicitly defined.
    # If the wavelength array is descending instead of ascending, both
    # points and lookup_table need to be reversed in the inverse transform
    # for scipy.interpolate to work properly
    points = wavelength_array
    lookup_table = np.arange(0, stop, step)
    if not np.all(np.diff(wavelength_array) > 0):
        points = points[::-1]
        lookup_table = lookup_table[::-1]
    pix_to_wavelength.inverse = Tabular1D(points=points,
                                          lookup_table=lookup_table,
                                          bounds_error=False,
                                          fill_value=None,
                                          name='wavelength2pix')

    # For the input mapping, duplicate the spatial coordinate
    mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

    # Sometimes the slit is perpendicular to the RA or Dec axis.
    # For example, if the slit is perpendicular to RA, that means
    # the slope of pix_to_xtan will be nearly zero, so make sure
    # mapping.inverse uses pix_to_ytan.inverse.  The auto definition
    # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.
    if np.isclose(pix_to_ytan.slope, 0, atol=1e-8):
        mapping_tuple = (0, 1)
        # Account for vertical or horizontal dispersion on detector
        if spatial_axis:
            mapping.inverse = Mapping(mapping_tuple[::-1])
        else:
            mapping.inverse = Mapping(mapping_tuple)

    # The final transform
    # redefine the ra, dec center tangent point to include all data

    # check if all_ra crosses 0 degrees - this makes it hard to
    # define the min and max ra correctly
    all_ra = wrap_ra(all_ra)
    ra_min = np.amin(all_ra)
    ra_max = np.amax(all_ra)
    ra_center_final = (ra_max + ra_min) / 2.0

    dec_min = np.amin(all_dec)
    dec_max = np.amax(all_dec)
    dec_center_final = (dec_max + dec_min) / 2.0
    tan = Pix2Sky_TAN()
    single_model = len(self.input_models) == 1
    if single_model:
        # single model use ra_center_pt to be consistent
        # with how resample was done before
        ra_center_final = ra_center_pt
        dec_center_final = dec_center_pt

    # ``model`` is the last model of the loop above
    sky_like = resample_utils.is_sky_like(model.meta.wcs.output_frame)

    if sky_like:
        native2celestial = RotateNative2Celestial(ra_center_final,
                                                  dec_center_final, 180)
        undist2sky = tan | native2celestial
        # find the spatial size of the output - same in x,y
        x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
    else:
        x_tan_all, _ = all_ra, all_dec
    x_min = np.amin(x_tan_all)
    x_max = np.amax(x_tan_all)
    x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_xtan.slope)))

    # single model use size of x_tan_array
    # to be consistent with method before
    if single_model:
        x_size = len(x_tan_array)

    # define the output wcs
    if sky_like:
        transform = mapping | (pix_to_xtan & pix_to_ytan
                               | undist2sky) & pix_to_wavelength
    else:
        transform = mapping | (pix_to_xtan
                               & pix_to_ytan) & pix_to_wavelength

    det = cf.Frame2D(name='detector', axes_order=(0, 1))
    if sky_like:
        sky = cf.CelestialFrame(name='sky', axes_order=(0, 1),
                                reference_frame=coord.ICRS())
    else:
        sky = cf.Frame2D(
            name=f'resampled_{model.meta.wcs.output_frame.name}',
            axes_order=(0, 1))
    spec = cf.SpectralFrame(name='spectral', axes_order=(2, ),
                            unit=(u.micron, ),
                            axes_names=('wavelength', ))
    world = cf.CompositeFrame([sky, spec], name='world')

    pipeline = [(det, transform), (world, None)]

    output_wcs = WCS(pipeline)

    # compute the output array size in WCS axes order, i.e. (x, y)
    output_array_size = [0, 0]
    output_array_size[spectral_axis] = int(
        np.ceil(len(wavelength_array) / self.pscale_ratio))
    output_array_size[spatial_axis] = int(
        np.ceil(x_size / self.pscale_ratio))

    # turn the size into a numpy shape in (y, x) order
    self.data_size = tuple(output_array_size[::-1])
    bounding_box = resample_utils.wcs_bbox_from_shape(self.data_size)
    output_wcs.bounding_box = bounding_box

    return output_wcs
def build_nirspec_output_wcs(self, refwcs=None):
    """
    Create a simple output wcs covering footprint of the input datamodels

    Only the first model in ``self.input_models`` is used.  The slit
    position, angular size and rotation are measured from the reference WCS
    evaluated over its domain, and a new WCS built from
    shift/rotate/scale/TAN/sky-rotation (spatial) and scale/shift (spectral)
    transforms is created.

    Parameters
    ----------
    refwcs : WCS object, optional
        Reference WCS.  When ``None`` the ``meta.wcs`` of the first input
        model is used.

    Returns
    -------
    None
        Results are stored on the instance as ``output_spatial_scale``,
        ``output_spectral_scale`` and ``output_wcs``.
    """
    # TODO: generalize this for more than one input datamodel
    # TODO: generalize this for imaging modes with distorted wcs
    input_model = self.input_models[0]
    # Identity test: ``== None`` would dispatch to the object's own __eq__.
    if refwcs is None:
        refwcs = input_model.meta.wcs

    # Generate grid of sky coordinates for area within domain
    det = x, y = wcstools.grid_from_domain(refwcs.domain)
    sky = refwcs(*det)
    _, _, lam = sky
    domain_xsize = refwcs.domain[0]['upper'] - refwcs.domain[0]['lower']
    domain_ysize = refwcs.domain[1]['upper'] - refwcs.domain[1]['lower']
    x_center, y_center = int(domain_xsize / 2), int(domain_ysize / 2)

    # Compute slit angular size, slit center sky coords.
    # For every row, find the x position at which that row has the same
    # wavelength as the central pixel, interpolating over a small window of
    # columns around the central column.
    sz = 3
    window = slice(x_center - sz + 1, x_center + sz)
    xpos = []
    for row in lam:
        if np.isnan(row[x_center]):
            xpos.append(np.nan)
        else:
            f = interpolate.interp1d(row[window], x[y_center, window])
            xpos.append(f(lam[y_center, x_center]))

    # Rows whose central column has a valid wavelength
    valid_rows = ~np.isnan(lam[:, x_center])
    x_arg = np.array(xpos)[valid_rows]
    y_arg = y[valid_rows, x_center]
    slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
    slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
    # Pixel numbers run in the reverse order of the slit samples
    pix_num = np.flipud(np.arange(len(slit_ra)))
    interpol_ra = interpolate.interp1d(pix_num, slit_ra)
    interpol_dec = interpolate.interp1d(pix_num, slit_dec)
    slit_center_pix = len(slit_spec_ref) / 2. - 1
    log.debug('Slit center pix: {0}'.format(slit_center_pix))
    slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                               dec=interpol_dec(slit_center_pix),
                               unit=u.deg)
    log.debug('Slit center: {0}'.format(slit_center_sky))
    log.debug('Fiducial: {0}'.format(resample_utils.compute_spec_fiducial([refwcs])))
    angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
    log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
    dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
    offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
    log.debug('Offset up the slit: {0}'.format(offset_up_slit))

    # Compute spatial and spectral scales
    xposn = np.array(xpos)[~np.isnan(xpos)]
    dx = xposn[-1] - xposn[0]
    slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
    spatial_scale = angular_slit_size / slit_npix
    log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
    spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

    # Compute slit angle relative (clockwise) to y axis
    slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
    log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

    # Compute transform for output frame
    roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
    min_lam = np.nanmin(lam)
    offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)
    # TODO: double-check the signs on the following rotation angles
    rot = Rotation2D(roll_ref + slit_rot_angle)
    scale = Scale(spatial_scale) & Scale(spatial_scale)
    tan = Pix2Sky_TAN()
    lon_pole = _compute_lon_pole(slit_center_sky, tan)
    skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
                                    lon_pole)
    spatial_trans = offset | rot | scale | tan | skyrot
    spectral_trans = Scale(spectral_scale) | Shift(min_lam)
    # The second input feeds both spatial axes; the first feeds the
    # spectral axis.
    mapping = Mapping((1, 1, 0))
    mapping.inverse = Mapping((2, 1))
    transform = mapping | spatial_trans & spectral_trans
    transform.outputs = ('ra', 'dec', 'lamda')

    # Build the output wcs
    input_frame = refwcs.input_frame
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the domain in the output frame wcs object by running the input
    # sky grid back through the new transform.  The upper bound is exclusive,
    # hence the +1.
    domain_grid = wnew.backward_transform(*sky)
    domain = [
        {'lower': np.nanmin(domain_grid[axis]),
         'upper': np.nanmax(domain_grid[axis]) + 1,
         'includes_lower': True,
         'includes_upper': False}
        for axis in input_frame.axes_order
    ]
    log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
    wnew.domain = domain

    # Update class properties
    self.output_spatial_scale = spatial_scale
    self.output_spectral_scale = spectral_scale
    self.output_wcs = wnew
def build_miri_output_wcs(self, refwcs=None):
    """
    Create a simple output wcs covering footprint of the input datamodels

    Only the first model in ``self.input_models`` is used.  The slit size,
    slit center and the spatial/spectral scales are measured from the
    reference WCS evaluated over its bounding box (with any ``Rotation2D``
    found in the forward transform undone), and a new WCS is assembled from
    shift/scale/rotate/TAN/sky-rotation transforms.

    Parameters
    ----------
    refwcs : WCS object, optional
        Reference WCS.  When ``None`` the ``meta.wcs`` of the first input
        model is used.

    Returns
    -------
    None
        Results are stored on the instance as ``output_spatial_scale``,
        ``output_spectral_scale`` and ``output_wcs``.
    """
    # TODO: generalize this for more than one input datamodel
    # TODO: generalize this for imaging modes with distorted wcs
    input_model = self.input_models[0]
    # Identity test: ``== None`` would dispatch to the object's own __eq__.
    if refwcs is None:
        refwcs = input_model.meta.wcs

    x, y = wcstools.grid_from_bounding_box(refwcs.bounding_box,
                                           step=(1, 1), center=True)
    ra, dec, lam = refwcs(x.flatten(), y.flatten())
    # TODO: once astropy.modeling._Tabular is fixed, take out the
    # flatten() and reshape() code above and below
    ra = ra.reshape(x.shape)
    dec = dec.reshape(x.shape)
    lam = lam.reshape(x.shape)

    # Find rotation of the slit from y axis from the wcs forward transform
    # TODO: figure out if angle is necessary for MIRI.  See for discussion
    # https://github.com/STScI-JWST/jwst/pull/347
    rotation = [m for m in refwcs.forward_transform
                if isinstance(m, Rotation2D)]
    # With no Rotation2D in the pipeline there is nothing to undo, so the
    # forward transform itself is used below.  Without this default the
    # slit-center computation raised NameError in that case.
    refwcs_minus_rot = refwcs.forward_transform
    if rotation:
        rot_slit = functools.reduce(lambda left, right: left | right,
                                    rotation)
        unrotate = rot_slit.inverse
        refwcs_minus_rot = refwcs.forward_transform | unrotate & Identity(1)
        # Correct for this rotation in the wcs
        ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

    # Get the slit size at the center of the dispersion
    sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
    slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
    slit_angular_size = slit_coords[0].separation(slit_coords[-1])
    log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

    # Compute slit center from bounding_box
    dx0 = refwcs.bounding_box[0][0]
    dx1 = refwcs.bounding_box[0][1]
    dy0 = refwcs.bounding_box[1][0]
    dy1 = refwcs.bounding_box[1][1]
    slit_center_pix = (dx1 - dx0) / 2
    dispersion_center_pix = (dy1 - dy0) / 2
    slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                   dy0 + dispersion_center_pix)
    slit_center_sky = SkyCoord(ra=slit_center[0], dec=slit_center[1],
                               unit=u.deg)
    log.debug('slit center: {0}'.format(slit_center))

    # Compute spatial and spectral scales
    spatial_scale = slit_angular_size / slit_coords.shape[0]
    log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
    tcenter = int((dx1 - dx0) / 2)
    trace = lam[:, tcenter]
    trace = trace[~np.isnan(trace)]
    spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
    log.debug('spectral scale: {0}'.format(spectral_scale))

    # Compute transform for output frame
    log.debug('Slit center %s' % slit_center_pix)
    # TODO: double-check the signs on the following rotation angles
    roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
    rot = Rotation2D(roll_ref)
    tan = Pix2Sky_TAN()
    lon_pole = _compute_lon_pole(slit_center_sky, tan)
    skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
                                    lon_pole)
    min_lam = np.nanmin(lam)
    # The first input feeds both spatial axes; the second is the spectral
    # axis.
    mapping = Mapping((0, 0, 1))

    transform = (
        Shift(-slit_center_pix) & Identity(1)
        | Scale(spatial_scale) & Scale(spectral_scale)
        | Identity(1) & Shift(min_lam)
        | mapping
        | (rot | tan | skyrot) & Identity(1)
    )
    # Input names are labels; the original assigned the pixel grid arrays
    # themselves here.
    transform.inputs = ('x', 'y')
    transform.outputs = ('ra', 'dec', 'lamda')

    # Build the output wcs
    input_frame = refwcs.input_frame
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the bounding_box in the output frame wcs object by running the
    # input sky grid back through the new transform
    bounding_box_grid = wnew.backward_transform(ra, dec, lam)
    wnew.bounding_box = tuple(
        (np.nanmin(bounding_box_grid[axis]), np.nanmax(bounding_box_grid[axis]))
        for axis in input_frame.axes_order
    )

    # Update class properties
    self.output_spatial_scale = spatial_scale
    self.output_spectral_scale = spectral_scale
    self.output_wcs = wnew
def build_miri_output_wcs(self, refwcs=None):
    """
    Create a simple output wcs covering footprint of the input datamodels

    Only the first model in ``self.input_models`` is used.  The slit size,
    slit center and the spatial/spectral scales are measured from the
    reference WCS evaluated over its domain (with any ``Rotation2D`` found
    in the forward transform undone), and a new WCS is assembled from
    shift/scale/rotate/TAN/sky-rotation transforms.

    Parameters
    ----------
    refwcs : WCS object, optional
        Reference WCS.  When ``None`` the ``meta.wcs`` of the first input
        model is used.

    Returns
    -------
    None
        Results are stored on the instance as ``output_spatial_scale``,
        ``output_spectral_scale`` and ``output_wcs``.
    """
    # TODO: generalize this for more than one input datamodel
    # TODO: generalize this for imaging modes with distorted wcs
    input_model = self.input_models[0]
    # Identity test: ``== None`` would dispatch to the object's own __eq__.
    if refwcs is None:
        refwcs = input_model.meta.wcs

    # Generate grid of sky coordinates for area within domain
    x, y = wcstools.grid_from_domain(refwcs.domain)
    ra, dec, lam = refwcs(x.flatten(), y.flatten())
    # TODO: once astropy.modeling._Tabular is fixed, take out the
    # flatten() and reshape() code above and below
    ra = ra.reshape(x.shape)
    dec = dec.reshape(x.shape)
    lam = lam.reshape(x.shape)

    # Find rotation of the slit from y axis from the wcs forward transform
    # TODO: figure out if angle is necessary for MIRI.  See for discussion
    # https://github.com/STScI-JWST/jwst/pull/347
    rotation = [m for m in refwcs.forward_transform
                if isinstance(m, Rotation2D)]
    # With no Rotation2D in the pipeline there is nothing to undo, so the
    # forward transform itself is used below.  Without this default the
    # slit-center computation raised NameError in that case.
    refwcs_minus_rot = refwcs.forward_transform
    if rotation:
        rot_slit = functools.reduce(lambda left, right: left | right,
                                    rotation)
        unrotate = rot_slit.inverse
        refwcs_minus_rot = refwcs.forward_transform | unrotate & Identity(1)
        # Correct for this rotation in the wcs
        ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

    # Get the slit size at the center of the dispersion
    sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
    slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
    slit_angular_size = slit_coords[0].separation(slit_coords[-1])
    log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

    # Compute slit center from domain
    dx0 = refwcs.domain[0]['lower']
    dx1 = refwcs.domain[0]['upper']
    dy0 = refwcs.domain[1]['lower']
    dy1 = refwcs.domain[1]['upper']
    slit_center_pix = (dx1 - dx0) / 2
    dispersion_center_pix = (dy1 - dy0) / 2
    slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                   dy0 + dispersion_center_pix)
    slit_center_sky = SkyCoord(ra=slit_center[0], dec=slit_center[1],
                               unit=u.deg)
    log.debug('slit center: {0}'.format(slit_center))

    # Compute spatial and spectral scales
    spatial_scale = slit_angular_size / slit_coords.shape[0]
    log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
    tcenter = int((dx1 - dx0) / 2)
    trace = lam[:, tcenter]
    trace = trace[~np.isnan(trace)]
    spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
    log.debug('spectral scale: {0}'.format(spectral_scale))

    # Compute transform for output frame
    log.debug('Slit center %s' % slit_center_pix)
    # TODO: double-check the signs on the following rotation angles
    roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
    rot = Rotation2D(roll_ref)
    tan = Pix2Sky_TAN()
    lon_pole = _compute_lon_pole(slit_center_sky, tan)
    skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
                                    lon_pole)
    min_lam = np.nanmin(lam)
    # The first input feeds both spatial axes; the second is the spectral
    # axis.
    mapping = Mapping((0, 0, 1))

    transform = (
        Shift(-slit_center_pix) & Identity(1)
        | Scale(spatial_scale) & Scale(spectral_scale)
        | Identity(1) & Shift(min_lam)
        | mapping
        | (rot | tan | skyrot) & Identity(1)
    )
    # Input names are labels; the original assigned the pixel grid arrays
    # themselves here.
    transform.inputs = ('x', 'y')
    transform.outputs = ('ra', 'dec', 'lamda')

    # Build the output wcs
    input_frame = refwcs.input_frame
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the domain in the output frame wcs object by running the input
    # sky grid back through the new transform.  The upper bound is exclusive,
    # hence the +1.
    domain_grid = wnew.backward_transform(ra, dec, lam)
    domain = [
        {'lower': np.nanmin(domain_grid[axis]),
         'upper': np.nanmax(domain_grid[axis]) + 1,
         'includes_lower': True,
         'includes_upper': False}
        for axis in input_frame.axes_order
    ]
    log.debug('Domain: {0} {1}'.format(domain[0]['lower'], domain[0]['upper']))
    log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
    wnew.domain = domain

    # Update class properties
    self.output_spatial_scale = spatial_scale
    self.output_spectral_scale = spectral_scale
    self.output_wcs = wnew