def slit_data_a():
    """Build a one-slit "science" MultiSlitModel for testing.

    The single slit has a 5 x 9 float32 data array filled with 10.0, an
    all-zero uint32 DQ array, and a float32 wavelength array.  Row 0 of the
    wavelength array runs linearly from 1.3 to 4.8; every following row is
    offset by a further 0.1, and all values are rounded to 4 decimals.

    Returns
    -------
    input_model : `~jwst.datamodels.MultiSlitModel`
    """
    shape = (5, 9)
    ncols = shape[1]

    sci = np.full(shape, 10., dtype=np.float32)
    dq_flags = np.zeros(shape, dtype=np.uint32)

    # Base wavelength row, evenly spaced and including both end points.
    base_row = np.linspace(1.3, 4.8, num=ncols, endpoint=True,
                           retstep=False, dtype=np.float32)

    # Each successive row is shifted by one more `row_step`.
    row_step = 0.1
    wl = np.zeros(shape, dtype=np.float32)
    for row_idx, row in enumerate(wl):
        row[:] = base_row + row_idx * row_step
    wl = np.around(wl, 4)

    model = datamodels.MultiSlitModel()
    model.slits.append(
        datamodels.SlitModel(init=None, data=sci, dq=dq_flags, wavelength=wl))
    return model
def test_has_uniform_source():
    """Check which source types the barshadow step treats as uniform."""
    slitlet = datamodels.SlitModel(data=np.zeros((10, 100), dtype=np.float32))

    # With no source_type set yet, the barshadow step assumes the source
    # is extended.
    assert bar.has_uniform_source(slitlet)

    # A point source is not extended; any other value (e.g. 'UNKNOWN') is
    # treated as extended.
    for src_type, expect_uniform in (('POINT', False), ('UNKNOWN', True)):
        slitlet.source_type = src_type
        assert bool(bar.has_uniform_source(slitlet)) is expect_uniform
def test_fs_correction():
    """Test application of FS corrections.

    The expected data is the input scaled by the uniform/point ratios of
    the flat-field and pathloss arrays and by the point/uniform ratio of
    the photom arrays.
    """
    data = np.ones((5, 5))
    ff_ps = 1.5 * data
    ff_un = data / 1.2
    pl_ps = 2 * data
    pl_un = data / 2.1
    ph_ps = 1.1 * data
    ph_un = 1.23 * data

    # Named input_model (not `input`) to avoid shadowing the builtin.
    input_model = datamodels.SlitModel(data=data,
                                       flatfield_point=ff_ps,
                                       flatfield_uniform=ff_un,
                                       pathloss_point=pl_ps,
                                       pathloss_uniform=pl_un,
                                       photom_point=ph_ps,
                                       photom_uniform=ph_un)

    # Compute the expectation before the call, in case correct_nrs_fs_bkg
    # modifies the model's data in place.
    corrected = input_model.data * (ff_un / ff_ps) * (pl_un / pl_ps) * (ph_ps / ph_un)

    result = correct_nrs_fs_bkg(input_model, primary_slit=True)

    assert np.allclose(corrected, result.data, rtol=1.e-7)
def _corrections_for_fixedslit(slit, pathloss, exp_type, source_type=None):
    """Calculate the correction arrays for a Fixed-slit slit

    Parameters
    ----------
    slit : jwst.datamodels.SlitModel
        The slit being operated on.

    pathloss : jwst.datamodels.DataModel
        The pathloss reference data

    exp_type : str
        Exposure type

    source_type : str or None
        Force processing using the specified source type. If None (or
        empty), the slit's own ``source_type`` is used.

    Returns
    -------
    correction : jwst.datamodels.SlitModel or None
        The correction arrays. ``data`` holds the correction selected for
        this slit's source type; ``pathloss_point`` and ``pathloss_uniform``
        hold both variants. None when no matching aperture is found in the
        reference data or the source lies outside the slit.
    """
    # Get centering
    xcenter, ycenter = get_center(exp_type, slit)

    # Get the aperture from the reference file that matches the slit
    aperture = get_aperture_from_model(pathloss, slit.name)
    if aperture is None:
        log.warning(f'Cannot find matching pathloss model for {slit.name}')
        log.warning('Skipping pathloss correction for this slit')
        return None

    log.info(f'Using aperture {aperture.name}')

    # Calculate the 1-d wavelength and pathloss vectors for the source position
    (wavelength_pointsource,
     pathloss_pointsource_vector,
     is_inside_slit) = calculate_pathloss_vector(aperture.pointsource_data,
                                                 aperture.pointsource_wcs,
                                                 xcenter, ycenter)
    # Only the point-source call's inside-slit flag is used.
    (wavelength_uniformsource,
     pathloss_uniform_vector,
     _) = calculate_pathloss_vector(aperture.uniform_data,
                                    aperture.uniform_wcs,
                                    xcenter, ycenter)

    if not is_inside_slit:
        log.warning('Source is outside slit. Skipping '
                    f'pathloss correction for slit {slit.name}')
        return None

    # Wavelengths in the reference file are in meters,
    # need them to be in microns.
    # NOTE(review): scaled in place; assumes calculate_pathloss_vector
    # returns fresh arrays rather than views of the reference data - confirm.
    wavelength_pointsource *= 1.0e6
    wavelength_uniformsource *= 1.0e6

    wavelength_array = slit.wavelength

    # Compute the point source pathloss 2D correction
    pathloss_2d_ps = interpolate_onto_grid(
        wavelength_array,
        wavelength_pointsource,
        pathloss_pointsource_vector)

    # Compute the uniform source pathloss 2D correction
    pathloss_2d_un = interpolate_onto_grid(
        wavelength_array,
        wavelength_uniformsource,
        pathloss_uniform_vector)

    # Use the appropriate correction for this slit
    if is_pointsource(source_type or slit.source_type):
        pathloss_2d = pathloss_2d_ps
    else:
        pathloss_2d = pathloss_2d_un

    # Save the corrections. The `data` portion is the correction used.
    # The individual ones will be saved in the respective attributes.
    correction = datamodels.SlitModel(data=pathloss_2d)
    correction.pathloss_point = pathloss_2d_ps
    correction.pathloss_uniform = pathloss_2d_un
    return correction
def _corrections_for_mos(slit, pathloss, exp_type, source_type=None):
    """Calculate the correction arrays for a MOS slit

    Parameters
    ----------
    slit : jwst.datamodels.SlitModel
        The slit being operated on.

    pathloss : jwst.datamodels.DataModel
        The pathloss reference data

    exp_type : str
        Exposure type

    source_type : str or None
        Force processing using the specified source type. If None (or
        empty), the slit's own ``source_type`` is used.

    Returns
    -------
    correction : jwst.datamodels.SlitModel or None
        The correction arrays. ``data`` holds the correction selected for
        this slit's source type; ``pathloss_point`` and ``pathloss_uniform``
        hold both variants. None when the slit has no data, no aperture
        matches its open-shutter count, or the source lies outside the
        slitlet.
    """
    size = slit.data.size

    # Only work on slits with data.size > 0
    if size <= 0:
        log.warning(f"Slit has data size = {size}")
        return None

    # Get centering
    xcenter, ycenter = get_center(exp_type, slit)

    # Look up the reference-file aperture using the number of open
    # shutters in this slitlet.
    nshutters = util.get_num_msa_open_shutters(slit.shutter_state)
    aperture = get_aperture_from_model(pathloss, nshutters)
    if aperture is None:
        # The trailing space is needed: the two literals are concatenated.
        log.warning("Cannot find matching pathloss model for slit with "
                    f"{nshutters} shutters")
        return None

    # Calculate the 1-d wavelength and pathloss vectors
    # for the source position
    (wavelength_pointsource,
     pathloss_pointsource_vector,
     is_inside_slitlet) = calculate_pathloss_vector(aperture.pointsource_data,
                                                    aperture.pointsource_wcs,
                                                    xcenter, ycenter)
    # Only the point-source call's inside-slitlet flag is used.
    (wavelength_uniformsource,
     pathloss_uniform_vector,
     _) = calculate_pathloss_vector(aperture.uniform_data,
                                    aperture.uniform_wcs,
                                    xcenter, ycenter)

    if not is_inside_slitlet:
        log.warning("Source is outside slit.")
        return None

    # Wavelengths in the reference file are in meters,
    # need them to be in microns.
    # NOTE(review): scaled in place; assumes calculate_pathloss_vector
    # returns fresh arrays rather than views of the reference data - confirm.
    wavelength_pointsource *= 1.0e6
    wavelength_uniformsource *= 1.0e6

    wavelength_array = slit.wavelength

    # Compute the point source pathloss 2D correction
    pathloss_2d_ps = interpolate_onto_grid(
        wavelength_array,
        wavelength_pointsource,
        pathloss_pointsource_vector)

    # Compute the uniform source pathloss 2D correction
    pathloss_2d_un = interpolate_onto_grid(
        wavelength_array,
        wavelength_uniformsource,
        pathloss_uniform_vector)

    # Use the appropriate correction for this slit
    if is_pointsource(source_type or slit.source_type):
        pathloss_2d = pathloss_2d_ps
    else:
        pathloss_2d = pathloss_2d_un

    # Save the corrections. The `data` portion is the correction used.
    # The individual ones will be saved in the respective attributes.
    correction = datamodels.SlitModel(data=pathloss_2d)
    correction.pathloss_point = pathloss_2d_ps
    correction.pathloss_uniform = pathloss_2d_un
    return correction
def _calc_correction(slitlet, barshadow_model, source_type):
    """Calculate the barshadow correction for a slitlet

    Parameters
    ----------
    slitlet : jwst.datamodels.SlitModel
        The slitlet to calculate for.

    barshadow_model : `~jwst.datamodels.BarshadowModel`
        bar shadow data model from reference file

    source_type : str or None
        Force processing using the specified source type.

    Returns
    -------
    correction : jwst.datamodels.SlitModel
        The correction to be applied. Its ``data`` has the shape of
        ``slitlet.data`` and is all ones (no correction) unless the source
        is treated as uniform and the slitlet has a non-empty
        ``shutter_state``.
    """
    slitlet_number = slitlet.slitlet_id

    # Reference-file keywords: wavelength origin and step (used to map
    # wavelength to a column) and the row step (used to map slit y to a row).
    w0 = barshadow_model.crval1
    wave_increment = barshadow_model.cdelt1
    y_increment = barshadow_model.cdelt2
    shutter_height = 1.0 / y_increment

    # Start from a correction of ones, i.e. leave the data unchanged.
    correction = datamodels.SlitModel(data=np.ones(slitlet.data.shape))

    # The correction only applies to extended/uniform sources
    if not has_uniform_source(slitlet, source_type):
        log.info("Bar shadow correction skipped for slitlet %d "
                 "(source not uniform)", slitlet_number)
        return correction

    shutter_status = slitlet.shutter_state
    # A missing (None) or empty shutter state leaves nothing to model.
    if not shutter_status:
        log.info("Slitlet %d has zero length, correction skipped",
                 slitlet_number)
        return correction

    # Create the pieces that are put together to make the barshadow model;
    # only done here, once the correction is known to be needed.
    shutter_elements = create_shutter_elements(barshadow_model)
    shadow = create_shadow(shutter_elements, shutter_status)

    # For each pixel in the slit subarray,
    # make a grid of indices for pixels in the subarray
    x, y = wcstools.grid_from_bounding_box(slitlet.meta.wcs.bounding_box,
                                           step=(1, 1))

    # Create the transformation from detector to slit_frame
    det2slit = slitlet.meta.wcs.get_transform('detector', 'slit_frame')

    # Use this transformation to calculate slit y and wavelength
    # (the slit x coordinate is not needed)
    _, yslit, wavelength = det2slit(x, y)

    nshutters = len(shutter_status)
    # Index of the slitlet's central shutter in shutter_status
    center_index = (nshutters - 1) / 2.0

    # If the source position is off-center in the slit, renormalize the yslit
    # values so that it appears as if the source is centered, which is the
    # appropriate way to compute the shadow correction for extended/uniform
    # sources (doesn't depend on source location).
    if nshutters > 1:
        src_loc = shutter_status.find('x')
        if src_loc != -1 and float(src_loc) != center_index:
            yslit -= np.nanmean(yslit)

    # The returned y values are scaled to where the slit height is 1
    # (i.e. a slit goes from -0.5 to 0.5). The barshadow array is scaled
    # so that the separation between the slit centers is 1,
    # i.e. slit height + interslit bar
    yslit = yslit / SLITRATIO

    # Convert the Y and wavelength to a pixel location in the bar shadow
    # array; the fiducial should always be at the center of the slitlet,
    # regardless of where the source is centered.
    # The shutters go downwards, i.e. the first shutter in shutter_status
    # corresponds to the last in the shadow array. So the center of the first
    # shutter referred to in shutter_status has an index of
    # shadow.shape[0] - shutter_height. Each subsequent shutter center has an
    # index shutter_height greater.
    index_of_fiducial_in_array = shadow.shape[0] - shutter_height * (
        1 + center_index)
    yrow = index_of_fiducial_in_array + yslit * shutter_height
    wcol = (wavelength - w0) / wave_increment

    # Interpolate the bar shadow correction for non-Nan pixels
    correction.data = interpolate(yrow, wcol, shadow)
    return correction
def contam_corr(input_model, waverange, photom, max_cores):
    """
    The main WFSS contamination correction function

    Builds a full-frame simulated grism image from the segmentation map and
    direct image, then for each slit subtracts a cutout of that simulation
    with the slit's own source removed.

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Input data model containing 2D spectral cutouts

    waverange : `~jwst.datamodels.WavelengthrangeModel`
        Wavelength range reference file model

    photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel`
        Photom (flux cal) reference file model

    max_cores : string
        Number of cores to use for multiprocessing. If set to 'none', then
        no multiprocessing will be done. The other allowable values are
        'quarter', 'half', and 'all', which indicate the fraction of cores
        to use for multi-proc ('quarter' and 'half' use at least one core).
        Any other value also results in a single core. The total number of
        cores includes the SMT cores (Hyper Threading for Intel).

    Returns
    -------
    output_model : `~jwst.datamodels.MultiSlitModel`
        A copy of the input_model that has been corrected

    simul_model : `~jwst.datamodels.ImageModel`
        Full-frame simulated image of the grism exposure

    contam_model : `~jwst.datamodels.MultiSlitModel`
        Contamination estimate images for each source slit
    """
    # Determine number of cpu's to use for multi-processing
    if max_cores == 'none':
        ncpus = 1
    else:
        num_cores = multiprocessing.cpu_count()
        if max_cores == 'quarter':
            ncpus = num_cores // 4 or 1
        elif max_cores == 'half':
            ncpus = num_cores // 2 or 1
        elif max_cores == 'all':
            ncpus = num_cores
        else:
            # Unrecognized values fall back to a single core
            ncpus = 1
        log.debug(f"Found {num_cores} cores; using {ncpus}")

    # Initialize output model
    output_model = input_model.copy()

    # Get the segmentation map for this grism exposure
    # NOTE(review): seg_model is never explicitly closed; it is handed to
    # Observation below, which may still need it - confirm before closing.
    seg_model = datamodels.open(input_model.meta.segmentation_map)

    # Get the direct image from which the segmentation map was constructed
    direct_file = input_model.meta.direct_image
    image_names = [direct_file]
    log.debug(f"Direct image names={image_names}")

    # Get the grism WCS from the input model
    grism_wcs = input_model.slits[0].meta.wcs

    # Find out how many spectral orders are defined, based on the
    # array of order values in the Wavelengthrange ref file
    spec_orders = np.asarray(waverange.order)
    spec_orders = spec_orders[spec_orders != 0]  # ignore any order 0 entries
    log.debug(f"Spectral orders defined = {spec_orders}")

    # Get the FILTER and PUPIL wheel positions, for use later
    filter_kwd = input_model.meta.instrument.filter
    pupil_kwd = input_model.meta.instrument.pupil

    # NOTE: The NIRCam WFSS mode uses filters that are in the FILTER wheel
    # with gratings in the PUPIL wheel. NIRISS WFSS mode, however, is just
    # the opposite. It has gratings in the FILTER wheel and filters in the
    # PUPIL wheel. So when processing NIRISS grism exposures the name of
    # filter needs to come from the PUPIL keyword value.
    if input_model.meta.instrument.name == 'NIRISS':
        filter_name = pupil_kwd
    else:
        filter_name = filter_kwd

    # Load lists of wavelength ranges and flux cal info for all orders,
    # keyed by spectral order
    wmin = {}
    wmax = {}
    sens_waves = {}
    sens_response = {}
    for order in spec_orders:
        wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order])
        wmin[order] = wavelength_range[order][0]
        wmax[order] = wavelength_range[order][1]
        # Load the sensitivity (inverse flux cal) data for this mode and order
        sens_waves[order], sens_response[order] = get_photom_data(photom, filter_kwd, pupil_kwd, order)
    log.debug(f"wmin={wmin}, wmax={wmax}")

    # Initialize the simulated image object
    simul_all = None
    # NOTE(review): boundaries presumably give the full 2048x2048 detector
    # extent - confirm against Observation.
    obs = Observation(image_names, seg_model, grism_wcs, filter_name,
                      boundaries=[0, 2047, 0, 2047], max_cpu=ncpus)

    # Create simulated grism image for each order and sum them up
    for order in spec_orders:
        log.info(f"Creating full simulated grism image for order {order}")
        obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order],
                         sens_response[order])

        # Accumulate result for this order into the combined image
        if simul_all is None:
            # NOTE(review): this aliases obs.simulated_image (no copy); the
            # in-place += below is only correct if disperse_all rebinds
            # obs.simulated_image to a fresh array on each call - confirm.
            simul_all = obs.simulated_image
        else:
            simul_all += obs.simulated_image

    # Save the full-frame simulated grism image
    simul_model = datamodels.ImageModel(data=simul_all)
    simul_model.update(input_model, only="PRIMARY")

    # Loop over all slits/sources to subtract contaminating spectra
    log.info("Creating contamination image for each individual source")
    contam_model = datamodels.MultiSlitModel()
    contam_model.update(input_model)
    slits = []
    for slit in output_model.slits:

        # Create simulated spectrum for this source only
        sid = slit.source_id
        # wmin/wmax/sens_* lookups below raise KeyError if this slit's order
        # is not one of spec_orders.
        order = slit.meta.wcsinfo.spectral_order
        # Raises IndexError if sid is not present in obs.IDs.
        chunk = np.where(obs.IDs == sid)[0][0]  # find chunk for this source

        # Reset the simulated image so that, after disperse_chunk, it holds
        # only this source's dispersed spectrum.
        obs.simulated_image = np.zeros(obs.dims)
        obs.disperse_chunk(chunk, order, wmin[order], wmax[order],
                           sens_waves[order], sens_response[order])
        this_source = obs.simulated_image

        # Contamination estimate is full simulated image minus this source
        contam = simul_all - this_source

        # Create a cutout of the contam image that matches the extent
        # of the source slit (xstart/ystart are presumably 1-based, hence
        # the -1 to get 0-based slice starts)
        x1 = slit.xstart - 1
        x2 = x1 + slit.xsize
        y1 = slit.ystart - 1
        y2 = y1 + slit.ysize
        cutout = contam[y1:y2, x1:x2]
        new_slit = datamodels.SlitModel(data=cutout)
        copy_slit_info(slit, new_slit)
        slits.append(new_slit)

        # Subtract the cutout from the source slit (in place, on the slit
        # belonging to output_model)
        slit.data -= cutout

    # Save the contamination estimates for all slits
    contam_model.slits.extend(slits)

    # Set the step status to COMPLETE
    output_model.meta.cal_step.wfss_contam = 'COMPLETE'

    return output_model, simul_model, contam_model