def slit_data_a():
    """Build a small single-slit "science" model for testing.

    The slit holds a 5 x 9 float32 data array filled with 10, an all-zero
    uint32 DQ array, and a wavelength array whose rows are copies of one
    linear ramp (1.3 to 4.8, 9 samples), each row offset by 0.1 from the
    previous one and rounded to 4 decimals.

    Returns
    -------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Model containing the single slit described above.
    """
    n_rows, n_cols = shape = (5, 9)

    science = np.full(shape, 10., dtype=np.float32)
    quality = np.zeros(shape, dtype=np.uint32)

    # Wavelengths of the first row; the other rows are shifted copies.
    base_row = np.linspace(1.3, 4.8, num=n_cols, endpoint=True,
                           retstep=False, dtype=np.float32)
    row_step = 0.1
    wavelength = np.zeros(shape, dtype=np.float32)
    for row_index, row in enumerate(wavelength):
        # ``row`` is a view, so this writes into ``wavelength`` in place.
        row[:] = base_row + row_index * row_step
    wavelength = np.around(wavelength, 4)

    input_model = datamodels.MultiSlitModel()
    input_model.slits.append(
        datamodels.SlitModel(init=None, data=science, dq=quality,
                             wavelength=wavelength))
    return input_model
def test_nrs_msaspec():
    """Check the source type assigned to each slit of an NRS_MSASPEC exposure.

    Three slits are appended with stellarity 0.9, -1 and 0.5; after
    ``srctype.set_source_type`` they are expected to be POINT, POINT and
    EXTENDED respectively.
    """
    # (source_id, stellarity, expected source_type)
    cases = (
        (1, 0.9, 'POINT'),
        (2, -1, 'POINT'),
        (3, 0.5, 'EXTENDED'),
    )

    model = datamodels.MultiSlitModel()
    model.meta.exposure.type = "NRS_MSASPEC"
    for source_id, stellarity, _ in cases:
        model.slits.append({'source_id': source_id, 'stellarity': stellarity})

    result = srctype.set_source_type(model)

    for index, (_, _, expected) in enumerate(cases):
        assert result.slits[index].source_type == expected
def __init__(self, input_DM):
    """
    Short Summary
    -------------
    Open the input as a data model and keep it on the instance.

    The input is opened with ``datamodels.open``. A CubeModel or ImageModel
    is used as opened; any other DataModel is closed and reopened as a
    MultiSlitModel. If anything fails while opening, the error is logged
    (including its cause) and ``self.input`` is set to None, so callers
    must check for None before using it.

    Parameters
    ----------
    input_DM: data model object
        input Data Model object
    """
    self.input_file = input_DM
    try:
        model = datamodels.open(input_DM)
        keep_as_opened = isinstance(
            model, (datamodels.CubeModel, datamodels.ImageModel))
        if not keep_as_opened and isinstance(model, datamodels.DataModel):
            # Not a cube or image: reopen the same input as MultiSlit
            model.close()
            model = datamodels.MultiSlitModel(input_DM)
        self.input = model
    except Exception as errmess:
        # Deliberate best-effort: do not raise, but record why it failed
        log.error('Error opening %s: %s', input_DM, errmess)
        self.input = None
def create_background_from_multislit(input_model):
    """Build a 1D master background spectrum from the background slitlets
    of a calibrated MOS MultiSlitModel.

    Slits whose ``source_name`` contains "background" are copied into a
    temporary model, which is resampled, extracted to 1D and combined.

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        The input data model containing all slit instances.

    Returns
    -------
    master_bkg: `~jwst.datamodels.CombinedSpecModel`
        The 1D master background spectrum created from the inputs,
        or None when the input holds no background slitlets.
    """
    from ..resample import resample_spec_step
    from ..extract_1d import extract_1d_step
    from ..combine_1d.combine1d import combine_1d_spectra

    log.info('Creating MOS master background from background slitlets')

    # Temporary model that will hold only the dedicated background slitlets
    bkg_model = datamodels.MultiSlitModel()
    bkg_model.update(input_model)

    background_slits = []
    for candidate in input_model.slits:
        if "background" not in candidate.source_name:
            continue
        log.info(f'Using background slitlet {candidate.source_name}')
        background_slits.append(candidate)

    if not background_slits:
        log.warning(
            'No background slitlets found; skipping master bkg correction')
        return None

    bkg_model.slits.extend(background_slits)

    # Resample and extract every background slitlet
    log.info('Applying resampling and 1D extraction to background slits')
    resamp = resample_spec_step.ResampleSpecStep.call(bkg_model)
    x1d = extract_1d_step.Extract1dStep.call(resamp)

    # Merge the individual 1D background spectra into one
    log.info('Combining 1D background spectra into master background')
    master_bkg = combine_1d_spectra(x1d, exptime_key='exposure_time')

    # Drop the intermediate products before returning
    del bkg_model, resamp, x1d

    return master_bkg
def test_nrs_fixedslit():
    """Check the source type assigned to each slit of an NRS_FIXEDSLIT exposure.

    The model declares S200A1 as the fixed slit and an EXTENDED target.
    After ``srctype.set_source_type`` slit S200A1 is expected to be
    EXTENDED while S200A2 and S1600A1 are expected to be POINT.
    """
    # (slit name, expected source_type), in the order the slits are appended
    cases = (
        ('S200A2', 'POINT'),
        ('S200A1', 'EXTENDED'),
        ('S1600A1', 'POINT'),
    )

    model = datamodels.MultiSlitModel()
    model.meta.exposure.type = "NRS_FIXEDSLIT"
    model.meta.instrument.fixed_slit = "S200A1"
    model.meta.target.source_type = 'EXTENDED'
    for slit_name, _ in cases:
        model.slits.append({'name': slit_name})

    result = srctype.set_source_type(model)

    for index, (_, expected) in enumerate(cases):
        assert result.slits[index].source_type == expected
def compute_world_coordinates(fname, output=None):
    """
    Computes wavelengths, and space coordinates of a NIRSPEC
    FS or MOS observation after running extract_2d.

    For every slit in the input, the slit WCS is evaluated on the grid of
    its bounding box and a 4-plane image extension is written holding
    (lambda, first output-frame coordinate, second output-frame coordinate,
    slit-frame y).

    Parameters
    ----------
    fname : str
        The name of a file with extracted slits, i.e. the output of extract2d.
    output : str
        The name of the output file. The name is normalized so that it
        always ends in ``_world_coordinates.fits``: the extension is
        replaced with ``.fits`` and ``_world_coordinates`` is appended to
        the base name unless it already ends with ``world_coordinates``.
        If None, the part of the input model's ``meta.filename`` before
        the first underscore is used as the root.

    Raises
    ------
    ValueError
        If the exposure type is neither NRS_FIXEDSLIT nor NRS_MSASPEC.

    Examples
    --------
    >>> compute_world_coordinates('nrs1_fixed_assign_wcs_extract_2d.fits')
    """
    model = datamodels.MultiSlitModel(fname)
    # try/finally so the model is closed on the error path as well
    try:
        exp_type = model.meta.exposure.type
        if exp_type.lower() not in ('nrs_fixedslit', 'nrs_msaspec'):
            raise ValueError("Expected a FS or MOS observation,"
                             "(EXP_TYPE=NRS_FIXEDSLIT, NRS_MSASPEC),"
                             "got {0}".format(exp_type))

        hdulist = fits.HDUList()
        phdu = fits.PrimaryHDU()
        phdu.header['filename'] = model.meta.filename
        phdu.header['data'] = 'world coordinates'
        hdulist.append(phdu)

        # Name of the last frame of the first slit's WCS, used as label only
        output_frame = model.slits[0].meta.wcs.available_frames[-1]

        for slit in model.slits:
            x, y = wcstools.grid_from_bounding_box(slit.meta.wcs.bounding_box,
                                                   step=(1, 1), center=True)
            ra, dec, lam = slit.meta.wcs(x, y)
            detector2slit = slit.meta.wcs.get_transform('detector',
                                                        'slit_frame')
            # Only the slit-frame y coordinate is stored
            _, sy, _ = detector2slit(x, y)

            world_coordinates = np.array([lam, ra, dec, sy])
            imhdu = fits.ImageHDU(data=world_coordinates)
            imhdu.header['PLANE1'] = 'lambda, microns'
            imhdu.header['PLANE2'] = '{0}_x, arcsec'.format(output_frame)
            imhdu.header['PLANE3'] = '{0}_y, arcsec'.format(output_frame)
            imhdu.header['PLANE4'] = 'slit_y, relative to center (0, 0)'
            imhdu.header['SLIT'] = slit.name

            # add the overall subarray offset
            imhdu.header['CRVAL1'] = (slit.xstart - 1
                                      + model.meta.subarray.xstart)
            imhdu.header['CRVAL2'] = (slit.ystart - 1
                                      + model.meta.subarray.ystart)
            imhdu.header['CRPIX1'] = 1
            imhdu.header['CRPIX2'] = 1
            imhdu.header['CTYPE1'] = 'pixel'
            imhdu.header['CTYPE2'] = 'pixel'
            hdulist.append(imhdu)

        if output is not None:
            # splitext keeps the leading dot in the extension, and the
            # normalized name has to be assigned back to take effect.
            base, _ = os.path.splitext(output)
            if not base.endswith('world_coordinates'):
                base = "".join([base, '_world_coordinates'])
            output = "".join([base, '.fits'])
        else:
            root = model.meta.filename.split('_')
            output = "".join([root[0], '_world_coordinates', '.fits'])

        hdulist.writeto(output, overwrite=True)
    finally:
        model.close()
def pathtest(step_input_filename, reffile, comparison_filename, writefile=True,
             show_figs=True, save_figs=False, threshold_diff=1e-7, debug=False):
    """
    This function calculates the difference between the pipeline and the
    calculated pathloss values for each slit of a fixed-slit, point-source
    exposure.

    For every slit, a correction is interpolated from the pathloss reference
    file at the source position, mapped onto the slit wavelengths, and
    compared with the pathloss array stored in the comparison (pipeline)
    file. A slit passes when the absolute median of the residuals (after
    discarding NaN and residuals outside (-0.1, 0.1)) is at or below
    threshold_diff.

    Args:
        step_input_filename: str, name of the output fits file from the
            source type step (with full path). The extract_2d file and the
            re-processed S400A1 pathloss file are looked up by replacing
            "srctype.fits" in this name.
        reffile: str, path to the pathloss FS reference fits file
        comparison_filename: str, path to comparison pipeline pathloss file
        writefile: boolean, if True writes the fits files of the calculated
            pathloss and difference image.
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots (the function names the plot)
        threshold_diff: float, threshold difference between pipeline output
            and comparison file
        debug: boolean, if true a series of print statements will show
            on-screen

    Returns:
        A 3-tuple (result, result_msg, log_msgs):
        - result: True if at least one slit was tested and none FAILED,
          False otherwise; the string 'skip' when the comparison or input
          file is missing or the slit names of the two files do not match.
        - result_msg: str, summary of the outcome.
        - log_msgs: list, all print statements are captured in this variable

    Note:
        The output file name keeps the historical spelling
        "calcuated_FS_PS_pathloss" so that existing consumers of that file
        are not broken.
    """
    log_msgs = []

    # start the timer
    pathtest_start_time = time.time()

    # get info from the rate file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    msg = "path_loss_file --> Grating:" + grat + " Filter:" + filt + " EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)

    is_point_source = True
    # EXP_TYPE does not change during the run, so read it once
    is_bots = exptype == "NRS_BRIGHTOBJ"

    # get the datamodel from the assign_wcs output file
    extract2d_wcs_file = step_input_filename.replace("srctype.fits", "extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_wcs_file)

    if writefile:
        # create the fits list to hold the calculated pathloss values for
        # each slit
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create the fits list to hold the image of pipeline-calculated
        # difference values
        hdu0 = fits.PrimaryHDU()
        compfile = fits.HDUList()
        compfile.append(hdu0)

    # list to determine if pytest is passed or not
    total_test_result = []

    # loop over the slits
    sltname_list = ["S200A1", "S200A2", "S400A1", "S1600A1"]
    msg = "Now looping through the slits. This may take a while... "
    print(msg)
    log_msgs.append(msg)
    if det == "NRS2":
        sltname_list.append("S200B1")

    # but check if data is BOTS
    if is_bots:
        sltname_list = ["S1600A1"]

    # get all the science extensions
    ps_uni_ext_list = get_ps_uni_extensions(reffile, is_point_source)

    # get files
    print("""Checking if files exist & obtaining datamodels. This takes a few minutes...""")
    if os.path.isfile(comparison_filename):
        if debug:
            print('Comparison file does exist.')
    else:
        result_msg = 'Comparison file does NOT exist. Skipping pathloss test.'
        print(result_msg)
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the comparison data model
    pathloss_pipe = datamodels.open(comparison_filename)
    # For the moment, the pipeline is using the wrong reference file for
    # slit 400A1, so read the file that was re-processed with the right
    # reference file and open the corresponding data model
    pathloss_400a1 = step_input_filename.replace("srctype.fits", "pathloss_400A1.fits")
    pathloss_pipe_400a1 = datamodels.open(pathloss_400a1)
    if debug:
        print('got comparison datamodel!')

    if os.path.isfile(step_input_filename):
        if debug:
            print('Input file does exist.')
    else:
        result_msg = 'Input file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the input data model
    pl = datamodels.open(step_input_filename)
    if debug:
        print('got input datamodel!')

    # loop through the wavelengths
    msg = "Looping through the wavelengths... "
    print(msg)
    log_msgs.append(msg)

    slit_val = 0
    for slit, pipe_slit in zip(pl.slits, pathloss_pipe.slits):
        slit_val = slit_val + 1
        slit_id = pipe_slit.name
        print('\nWorking with slitlet ', slit_id)

        if slit.name == slit_id:
            msg = """Slitlet name in fits file previous to pathloss and in pathloss output file are the same. """
            log_msgs.append(msg)
            print(msg)
        else:
            msg = """* Missmatch of slitlet names in fits file previous to pathloss and in pathloss output file. Skipping test. """
            result = 'skip'
            log_msgs.append(msg)
            return result, msg, log_msgs

        # S-flat
        mode = "FS"

        if debug:
            print("grat = ", grat)

        continue_pl_test = False
        if is_bots:
            slit = model
            continue_pl_test = True
        else:
            for slit_in_MultiSlitModel in pl.slits:
                if slit_in_MultiSlitModel.name == slit_id:
                    slit = slit_in_MultiSlitModel
                    continue_pl_test = True
                    break

        if not continue_pl_test:
            continue

        try:
            if is_point_source is True:
                ext = ps_uni_ext_list[0][slit_id]
                print("Retrieved point source extension")
            elif is_point_source is False:
                ext = ps_uni_ext_list[1][slit_id]
                print("WARNING: Retrieved extended source extension")
        except KeyError:
            # fall back to the position of the slit in the name list
            ext = sltname_list.index(slit_id)
            print("Unable to retrieve extension.")

        wcs_obj = slit.meta.wcs

        # get the wavelength
        x, y = wcstools.grid_from_bounding_box(wcs_obj.bounding_box,
                                               step=(1, 1), center=True)
        ra, dec, wave = wcs_obj(x, y)  # wave is in microns
        wave_sci = wave * 10**(-6)  # microns --> meters

        # get positions of source in file:
        slit_x = slit.source_xpos
        slit_y = slit.source_ypos
        if debug:
            print("slit_x, slit_y", slit_x, slit_y)

        if slit_id == "S400A1":
            if is_point_source:
                ext = 1
            else:
                ext = 3
            reffile2use = "jwst-nirspec-a400.plrf.fits"
        else:
            reffile2use = reffile

        msg = "Using reference file: " + reffile2use
        print(msg)
        log_msgs.append(msg)

        plcor_ref_ext = fits.getdata(reffile2use, ext)
        if debug:
            print("ext:", ext)

        # Only the cube shape and the header WCS of extension 1 are needed,
        # so the reference file is closed again right away.
        with fits.open(reffile2use) as hdul:
            ref_cube_shape = hdul[1].data.shape
            w = wcs.WCS(hdul[1].header)

        # make cube
        w1, y1, x1 = np.mgrid[:ref_cube_shape[0], :ref_cube_shape[1],
                              :ref_cube_shape[2]]
        slitx_ref, slity_ref, wave_ref = w.all_pix2world(x1, y1, w1, 0)

        previous_sci = slit.data

        pipe_correction = pipe_slit.pathloss
        if slit_id == "S400A1":
            # use the correction from the re-processed S400A1 file instead
            for pipe_slit_400a1 in pathloss_pipe_400a1.slits:
                if pipe_slit_400a1.name == "S400A1":
                    pipe_correction = pipe_slit_400a1.pathloss
                    break

        if len(pipe_correction) == 0:
            print("Pipeline pathloss correction in datamodel is empty. Skipping testing this slit.")
            continue

        wave_sci_flat = wave_sci.reshape(wave_sci.size)
        wave_ref_flat = wave_ref.reshape(wave_ref.size)

        ref_xy = np.column_stack((slitx_ref.reshape(slitx_ref.size),
                                  slity_ref.reshape(slitx_ref.size)))

        # loop through slices in lambda from reference file; values are
        # collected in lists and converted once, instead of re-allocating
        # an array with np.append on every iteration
        correction_values = []
        lambda_values = []
        for lambda_val in wave_ref_flat:
            # index of the reference slice with this lambda; only the first
            # spatial position of wave_ref is searched because the others
            # repeat the same wavelengths
            index = np.where(wave_ref[:, 0, 0] == lambda_val)
            slice_index = index[0][0]
            # take slice where lambda=index:
            plcor_slice = plcor_ref_ext[slice_index].reshape(
                plcor_ref_ext[slice_index].size)
            # do 2d interpolation to get a single correction factor for
            # each slice
            corr_val = scipy.interpolate.griddata(ref_xy[:plcor_slice.size],
                                                  plcor_slice,
                                                  np.asarray([slit_x, slit_y]),
                                                  method='linear')
            correction_values.append(corr_val[0])
            # map to array with corresponding lambda
            lambda_values.append(lambda_val)
        correction_array = np.array(correction_values, dtype=float)
        lambda_array = np.array(lambda_values, dtype=float)

        # get correction value for each pixel
        corr_vals = np.interp(wave_sci_flat, lambda_array, correction_array)
        corr_vals = corr_vals.reshape(wave_sci.shape)
        corrected_array = previous_sci / corr_vals

        # set up generals for all the plots
        font = {'weight': 'normal', 'size': 7}
        matplotlib.rc('font', **font)

        # Plots:
        step_input_filepath = step_input_filename.replace(".fits", "")
        # my correction values
        fig = plt.figure()
        plt.subplot(221)
        norm = ImageNormalize(corr_vals)
        plt.imshow(corr_vals, norm=norm, aspect=10.0, origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Calculated Correction')
        plt.colorbar()
        # pipe correction
        plt.subplot(222)
        norm = ImageNormalize(pipe_correction)
        plt.imshow(pipe_correction, norm=norm, aspect=10.0, origin='lower',
                   cmap='viridis')
        plt.title("Pipeline Correction")
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.colorbar()
        # residuals (pipe correction - my correction)
        corr_residuals = pipe_correction - corr_vals
        plt.subplot(223)
        norm = ImageNormalize(corr_residuals)
        plt.imshow(corr_residuals, norm=norm, aspect=10.0, origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Correction residuals')
        plt.colorbar()
        # my science data after
        plt.subplot(224)
        norm = ImageNormalize(corrected_array)
        plt.imshow(corrected_array, norm=norm, aspect=10.0, origin='lower',
                   cmap='viridis')
        plt.title('My science data after pathloss')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.colorbar()
        fig.suptitle("FS PS Pathloss Correction Test for slit " + str(slit_id))

        if save_figs:
            plt_name = step_input_filepath + "_Pathloss_test_slitlet_" + str(mode) + "_" + str(slit_id) + "_" + \
                       str(slit_val) + ".png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        if not save_figs:
            # Previously chained to `if show_figs`, which made the
            # "Not saving plots" case unreachable.
            if show_figs:
                msg = "Not saving plots because save_figs was set to False."
            else:
                msg = "Not making plots because both show_figs and save_figs were set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        plt.clf()
        # release the figure so one is not left open per slit
        plt.close(fig)

        # create fits file to hold the calculated pathloss for each slit
        if writefile:
            msg = "Saving the fits files with the calculated pathloss for each slit..."
            print(msg)
            log_msgs.append(msg)

            # extension with the calculated pathloss values
            outfile_ext = fits.ImageHDU(corr_vals, name=slit_id)
            outfile.append(outfile_ext)

            # extension with the pipeline-calculated difference values
            compfile_ext = fits.ImageHDU(corr_residuals, name=slit_id)
            compfile.append(compfile_ext)

        if corr_residuals[~np.isnan(corr_residuals)].size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN. " \
                   "Test will be set to FAILED."
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            # ignore outliers
            corr_residuals = corr_residuals[
                np.where((corr_residuals != 999.0)
                         & (corr_residuals < 0.1)
                         & (corr_residuals > -0.1))]
            if corr_residuals.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values. " \
                       "Test will be set to FAILED."
                print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
            else:
                stats_and_strings = auxfunc.print_stats(corr_residuals,
                                                        "Difference",
                                                        float(threshold_diff),
                                                        abs=True)
                stats, stats_print_strings = stats_and_strings
                corr_residuals_mean, corr_residuals_median, corr_residuals_std = stats
                for msg in stats_print_strings:
                    log_msgs.append(msg)

                # This is the key argument for the assert pytest function
                median_diff = abs(corr_residuals_median) <= float(threshold_diff)
                if median_diff:
                    test_result = "PASSED"
                else:
                    test_result = "FAILED"

        msg = " *** Result of the test: " + test_result + "\n"
        print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

    if writefile:
        outfile_name = step_input_filename.replace("srctype", "calcuated_FS_PS_pathloss")
        compfile_name = step_input_filename.replace("srctype", "comparison_FS_PS_pathloss")

        # file with the calculated pathloss values for each slit
        outfile.writeto(outfile_name, overwrite=True)

        # file with the image of pipeline-calculated difference values
        compfile.writeto(compfile_name, overwrite=True)

        msg = "\nFits file with calculated pathloss values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline pathloss - calculated pathloss) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(compfile_name)
        log_msgs.append(compfile_name)

    # If all tests passed then pytest is marked as PASSED, else it is FAILED.
    # With no tested slit at all the result stays False.
    FINAL_TEST_RESULT = (bool(total_test_result)
                         and "FAILED" not in total_test_result)

    if FINAL_TEST_RESULT:
        msg = "\n *** Final pathloss test result reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED path_loss test."
    else:
        msg = "\n *** Final pathloss test result reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED path_loss test."

    # end the timer
    pathloss_end_time = time.time() - pathtest_start_time
    if pathloss_end_time > 60.0:
        pathloss_end_time = pathloss_end_time / 60.0  # in minutes
        time_units = " minutes to finish."
        if pathloss_end_time > 60.0:
            pathloss_end_time = pathloss_end_time / 60.  # in hours
            time_units = " hours to finish."
    else:
        time_units = " seconds to finish."
    # Concatenate with "+": the original used a comma here, which built a
    # tuple instead of a message string.
    pathloss_tot_time = "* Script FS_PS.py took " + repr(pathloss_end_time) + time_units
    print(pathloss_tot_time)
    log_msgs.append(pathloss_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
def compare_wcs(infile_name, esa_files_path, msa_conf_name, show_figs=True, save_figs=False, threshold_diff=1.0e-7, mode_used=None, debug=False): """ This function does the WCS comparison from the world coordinates calculated using the pipeline data model with the ESA intermediary files. Args: infile_name: str, name of the output fits file from the assign_wcs step (with full path) esa_files_path: str, full path of where to find all ESA intermediary products to make comparisons for the tests msa_conf_name: str, full path where to find the shutter configuration file show_figs: boolean, whether to show plots or not save_figs: boolean, save the plots or not threshold_diff: float, threshold difference between pipeline output and ESA file mode_used: string, mode used in the PTT configuration file debug: boolean, if true a series of print statements will show on-screen Returns: - plots, if told to save and/or show them. - median_diff: Boolean, True if smaller or equal to threshold - log_msgs: list, all print statements captured in this variable """ log_msgs = [] # get grating and filter info from the rate file header if mode_used is not None and mode_used == "MOS_sim": infile_name = infile_name.replace("assign_wcs", "extract_2d") msg = 'wcs validation test infile_name= ' + infile_name print(msg) log_msgs.append(msg) det = fits.getval(infile_name, "DETECTOR", 0) lamp = fits.getval(infile_name, "LAMP", 0) grat = fits.getval(infile_name, "GRATING", 0) filt = fits.getval(infile_name, "FILTER", 0) msametfl = fits.getval(infile_name, "MSAMETFL", 0) msg = "from assign_wcs file --> Detector: " + det + " Grating: " + grat + " Filter: " + filt + " Lamp: " + lamp print(msg) log_msgs.append(msg) # check that shutter configuration file in header is the same as given in PTT_config file if msametfl != os.path.basename(msa_conf_name): msg = "* WARNING! 
MSA config file name given in PTT_config file does not match the MSAMETFL keyword in main header.\n" print(msg) log_msgs.append(msg) # copy the MSA shutter configuration file into the pytest directory try: subprocess.run(["cp", msa_conf_name, "."]) except FileNotFoundError: msg1 = " * PTT is not able to locate the MSA shutter configuration file. Please make sure that the msa_conf_name variable in" msg2 = " the PTT_config.cfg file is pointing exactly to where the fits file exists (i.e. full path and name). " msg3 = " -> The WCS test is now set to skip and no plots will be generated. " print(msg1) print(msg2) print(msg3) log_msgs.append(msg1) log_msgs.append(msg2) log_msgs.append(msg3) FINAL_TEST_RESULT = "skip" return FINAL_TEST_RESULT, log_msgs # get shutter info from metadata shutter_info = fits.getdata( msa_conf_name, extname="SHUTTER_INFO") # this is generally ext=2 pslit = shutter_info.field("slitlet_id") quad = shutter_info.field("shutter_quadrant") row = shutter_info.field("shutter_row") col = shutter_info.field("shutter_column") msg = 'Using this MSA shutter configuration file: ' + msa_conf_name print(msg) log_msgs.append(msg) # get the datamodel from the assign_wcs output file if mode_used is None or mode_used != "MOS_sim": img = datamodels.ImageModel(infile_name) # these commands only work for the assign_wcs ouput file # loop over the slits #slits_list = nirspec.get_open_slits(img) # this function returns all open slitlets as defined in msa meta file, # however, some of them may not be projected on the detector, and those are later removed from the list of open # slitlets. 
To get the open and projected on the detector slitlets we use the following: slits_list = img.meta.wcs.get_transform('gwa', 'slit_frame').slits #print ('Open slits: ', slits_list, '\n') if debug: print("Instrument Configuration") print("Detector: {}".format(img.meta.instrument.detector)) print("GWA: {}".format(img.meta.instrument.grating)) print("Filter: {}".format(img.meta.instrument.filter)) print("Lamp: {}".format(img.meta.instrument.lamp_state)) print("GWA_XTILT: {}".format(img.meta.instrument.gwa_xtilt)) print("GWA_YTILT: {}".format(img.meta.instrument.gwa_ytilt)) print("GWA_TTILT: {}".format(img.meta.instrument.gwa_tilt)) elif mode_used == "MOS_sim": # this command works for the extract_2d and flat_field output files model = datamodels.MultiSlitModel(infile_name) slits_list = model.slits # list to determine if pytest is passed or not total_test_result = OrderedDict() # loop over the slices for slit in slits_list: name = slit.name msg = "\nWorking with slit: " + str(name) print(msg) log_msgs.append(msg) # get the right index in the list of open shutters pslit_list = pslit.tolist() slitlet_idx = pslit_list.index(int(name)) # Get the ESA trace #raw_data_root_file = "NRSV96215001001P0000000002103_1_491_SE_2016-01-24T01h25m07.cts.fits" # testing only _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file( ) msg = "Using this raw data file to find the corresponding ESA file: " + raw_data_root_file print(msg) log_msgs.append(msg) q, r, c = quad[slitlet_idx], row[slitlet_idx], col[slitlet_idx] msg = "Pipeline shutter info: quadrant= " + str( q) + " row= " + str(r) + " col=" + str(c) print(msg) log_msgs.append(msg) specifics = [q, r, c] esafile = auxfunc.get_esafile(esa_files_path, raw_data_root_file, "MOS", specifics) #esafile = "/Users/pena/Documents/PyCharmProjects/nirspec/pipeline/src/sandbox/zzzz/Trace_MOS_3_319_013_V96215001001P0000000002103_41543_JLAB88.fits" # testing only # skip the test if the esafile was not found if "ESA file not found" in 
esafile: msg1 = " * compare_wcs_mos.py is exiting because the corresponding ESA file was not found." msg2 = " -> The WCS test is now set to skip and no plots will be generated. " print(msg1) print(msg2) log_msgs.append(msg1) log_msgs.append(msg2) FINAL_TEST_RESULT = "skip" return FINAL_TEST_RESULT, log_msgs # Open the trace in the esafile if len(esafile) == 2: print(len(esafile[-1])) if len(esafile[-1]) == 0: esafile = esafile[0] msg = "Using this ESA file: \n" + str(esafile) print(msg) log_msgs.append(msg) with fits.open(esafile) as esahdulist: print("* ESA file contents ") esahdulist.info() esa_shutter_i = esahdulist[0].header['SHUTTERI'] esa_shutter_j = esahdulist[0].header['SHUTTERJ'] esa_quadrant = esahdulist[0].header['QUADRANT'] if debug: msg = "ESA shutter info: quadrant=" + esa_quadrant + " shutter_i=" + esa_shutter_i + " shutter_j=" + esa_shutter_j print(msg) log_msgs.append(msg) # first check if ESA shutter info is the same as pipeline msg = "For slitlet" + str(name) print(msg) log_msgs.append(msg) if q == esa_quadrant: msg = "\n -> Same quadrant for pipeline and ESA data: " + str( q) print(msg) log_msgs.append(msg) else: msg = "\n -> Missmatch of quadrant for pipeline and ESA data: " + str( q) + esa_quadrant print(msg) log_msgs.append(msg) if r == esa_shutter_i: msg = "\n -> Same row for pipeline and ESA data: " + str(r) print(msg) log_msgs.append(msg) else: msg = "\n -> Missmatch of row for pipeline and ESA data: " + str( r) + esa_shutter_i print(msg) log_msgs.append(msg) if c == esa_shutter_j: msg = "\n -> Same column for pipeline and ESA data: " + str( c) + "\n" print(msg) log_msgs.append(msg) else: msg = "\n -> Missmatch of column for pipeline and ESA data: " + str( c) + esa_shutter_j + "\n" print(msg) log_msgs.append(msg) # Assign variables according to detector skipv2v3test = True if det == "NRS1": try: esa_flux = fits.getdata(esafile, "DATA1") esa_wave = fits.getdata(esafile, "LAMBDA1") esa_slity = fits.getdata(esafile, "SLITY1") esa_msax = 
fits.getdata(esafile, "MSAX1") esa_msay = fits.getdata(esafile, "MSAY1") pyw = wcs.WCS(esahdulist['LAMBDA1'].header) try: esa_v2v3x = fits.getdata(esafile, "V2V3X1") esa_v2v3y = fits.getdata(esafile, "V2V3Y1") skipv2v3test = False except KeyError: msg = "Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions." print(msg) log_msgs.append(msg) except KeyError: msg = "PTT did not find ESA extensions that match detector NRS1, skipping test for this slitlet..." print(msg) log_msgs.append(msg) continue if det == "NRS2": try: esa_flux = fits.getdata(esafile, "DATA2") esa_wave = fits.getdata(esafile, "LAMBDA2") esa_slity = fits.getdata(esafile, "SLITY2") esa_msax = fits.getdata(esafile, "MSAX2") esa_msay = fits.getdata(esafile, "MSAY2") pyw = wcs.WCS(esahdulist['LAMBDA2'].header) try: esa_v2v3x = fits.getdata(esafile, "V2V3X2") esa_v2v3y = fits.getdata(esafile, "V2V3Y2") skipv2v3test = False except KeyError: msg = "Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions." print(msg) log_msgs.append(msg) except KeyError: msg = "PTT did not find ESA extensions that match detector NRS2, skipping test for this slitlet..." print(msg) log_msgs.append(msg) continue # get the WCS object for this particular slit if mode_used is None or mode_used != "MOS_sim": try: wcs_slice = nirspec.nrs_wcs_set_input(img, name) except: ValueError msg = "* WARNING: Slitlet " + name + " was not found in the model. Skipping test for this slitlet." print(msg) log_msgs.append(msg) continue elif mode_used == "MOS_sim": wcs_slice = model.slits[0].wcs # if we want to print all available transforms, uncomment line below #print(wcs_slice) # The WCS object attribute bounding_box shows all valid inputs, i.e. the actual area of the data according # to the slice. Inputs outside of the bounding_box return NaN values. 
#bbox = wcs_slice.bounding_box #print('wcs_slice.bounding_box: ', wcs_slice.bounding_box) # In different observing modes the WCS may have different coordinate frames. To see available frames # uncomment line below. #print("Avalable frames: ", wcs_slice.available_frames) if mode_used is None or mode_used != "MOS_sim": if debug: # To get specific pixel values use following syntax: det2slit = wcs_slice.get_transform('detector', 'slit_frame') slitx, slity, lam = det2slit(700, 1080) print("slitx: ", slitx) print("slity: ", slity) print("lambda: ", lam) if debug: # The number of inputs and outputs in each frame can vary. This can be checked with: print('Number on inputs: ', det2slit.n_inputs) print('Number on outputs: ', det2slit.n_outputs) # Create x, y indices using the Trace WCS pipey, pipex = np.mgrid[:esa_wave.shape[0], :esa_wave.shape[1]] esax, esay = pyw.all_pix2world(pipex, pipey, 0) if det == "NRS2": msg = "NRS2 needs a flip" print(msg) log_msgs.append(msg) esax = 2049 - esax esay = 2049 - esay # Compute pipeline RA, DEC, and lambda slitlet_test_result_list = [] pra, pdec, pwave = wcs_slice( esax - 1, esay - 1 ) # => RETURNS: RA, DEC, LAMBDA (lam *= 10**-6 to convert to microns) pwave *= 10**-6 # calculate and print statistics for slit-y and x relative differences slitlet_name = repr(r) + "_" + repr(c) tested_quantity = "Wavelength Difference" rel_diff_pwave_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_wave, pwave, tested_quantity) rel_diff_pwave_img, notnan_rel_diff_pwave, notnan_rel_diff_pwave_stats, print_stats_strings = rel_diff_pwave_data for msg in print_stats_strings: log_msgs.append(msg) result = auxfunc.does_median_pass_tes(notnan_rel_diff_pwave_stats[1], threshold_diff) msg = 'Result for test of ' + tested_quantity + ': ' + result print(msg) log_msgs.append(msg) slitlet_test_result_list.append({tested_quantity: result}) # get the transforms for pipeline slit-y det2slit = wcs_slice.get_transform('detector', 'slit_frame') 
slitx, slity, _ = det2slit(esax - 1, esay - 1) tested_quantity = "Slit-Y Difference" # calculate and print statistics for slit-y and x relative differences rel_diff_pslity_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_slity, slity, tested_quantity, abs=False) # calculate and print statistics for slit-y and x absolute differences #rel_diff_pslity_data = auxfunc.get_reldiffarr_and_stats(threshold_diff, esa_slity, esa_slity, slity, tested_quantity, abs=True) rel_diff_pslity_img, notnan_rel_diff_pslity, notnan_rel_diff_pslity_stats, print_stats_strings = rel_diff_pslity_data for msg in print_stats_strings: log_msgs.append(msg) result = auxfunc.does_median_pass_tes(notnan_rel_diff_pslity_stats[1], threshold_diff) msg = 'Result for test of ' + tested_quantity + ': ' + result print(msg) log_msgs.append(msg) slitlet_test_result_list.append({tested_quantity: result}) # do the same for MSA x, y and V2, V3 detector2msa = wcs_slice.get_transform("detector", "msa_frame") pmsax, pmsay, _ = detector2msa( esax - 1, esay - 1 ) # => RETURNS: msaX, msaY, LAMBDA (lam *= 10**-6 to convert to microns) # MSA-x tested_quantity = "MSA_X Difference" reldiffpmsax_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_msax, pmsax, tested_quantity) reldiffpmsax_img, notnan_reldiffpmsax, notnan_reldiffpmsax_stats, print_stats_strings = reldiffpmsax_data for msg in print_stats_strings: log_msgs.append(msg) result = auxfunc.does_median_pass_tes(notnan_reldiffpmsax_stats[1], threshold_diff) msg = 'Result for test of ' + tested_quantity + ': ' + result print(msg) log_msgs.append(msg) slitlet_test_result_list.append({tested_quantity: result}) # MSA-y tested_quantity = "MSA_Y Difference" reldiffpmsay_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_msay, pmsay, tested_quantity) reldiffpmsay_img, notnan_reldiffpmsay, notnan_reldiffpmsay_stats, print_stats_strings = reldiffpmsay_data for msg in print_stats_strings: log_msgs.append(msg) 
result = auxfunc.does_median_pass_tes(notnan_reldiffpmsay_stats[1], threshold_diff) msg = 'Result for test of ' + tested_quantity + ': ' + result print(msg) log_msgs.append(msg) slitlet_test_result_list.append({tested_quantity: result}) # V2 and V3 if not skipv2v3test: detector2v2v3 = wcs_slice.get_transform("detector", "v2v3") pv2, pv3, _ = detector2v2v3( esax - 1, esay - 1 ) # => RETURNS: v2, v3, LAMBDA (lam *= 10**-6 to convert to microns) tested_quantity = "V2 difference" # converting to degrees to compare with ESA reldiffpv2_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_v2v3x, pv2, tested_quantity) if reldiffpv2_data[-2][0] > 0.0: print( "\nConverting pipeline results to degrees to compare with ESA" ) pv2 = pv2 / 3600. reldiffpv2_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_v2v3x, pv2, tested_quantity) reldiffpv2_img, notnan_reldiffpv2, notnan_reldiffpv2_stats, print_stats_strings = reldiffpv2_data for msg in print_stats_strings: log_msgs.append(msg) result = auxfunc.does_median_pass_tes(notnan_reldiffpv2_stats[1], threshold_diff) msg = 'Result for test of ' + tested_quantity + ': ' + result print(msg) log_msgs.append(msg) slitlet_test_result_list.append({tested_quantity: result}) tested_quantity = "V3 difference" # converting to degrees to compare with ESA reldiffpv3_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_v2v3y, pv3, tested_quantity) if reldiffpv3_data[-2][0] > 0.0: print( "\nConverting pipeline results to degrees to compare with ESA" ) pv3 = pv3 / 3600. 
reldiffpv3_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, esa_slity, esa_v2v3y, pv3, tested_quantity) reldiffpv3_img, notnan_reldiffpv3, notnan_reldiffpv3_stats, print_stats_strings = reldiffpv3_data for msg in print_stats_strings: log_msgs.append(msg) result = auxfunc.does_median_pass_tes(notnan_reldiffpv3_stats[1], threshold_diff) msg = 'Result for test of ' + tested_quantity + ': ' + result print(msg) log_msgs.append(msg) slitlet_test_result_list.append({tested_quantity: result}) total_test_result[slitlet_name] = slitlet_test_result_list # PLOTS if show_figs or save_figs: # set the common variables basenameinfile_name = os.path.basename(infile_name) main_title = filt + " " + grat + " SLITLET=" + slitlet_name + "\n" bins = 15 # binning for the histograms, if None the function will automatically calculate them # lolim_x, uplim_x, lolim_y, uplim_y plt_origin = None # Wavelength title = main_title + r"Relative wavelength difference = $\Delta \lambda$" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{ESA}) / \lambda_{ESA}$", "N" info_hist = [xlabel, ylabel, bins, notnan_rel_diff_pwave_stats] if notnan_rel_diff_pwave_stats[1] is np.nan: msg = "Unable to create plot of relative wavelength difference." 
print(msg) log_msgs.append(msg) else: plt_name = infile_name.replace( basenameinfile_name, slitlet_name + "_" + det + "_rel_wave_diffs.pdf") auxfunc.plt_two_2Dimgandhist(rel_diff_pwave_img, notnan_rel_diff_pwave, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # Slit-y title = main_title + r"Relative slit position = $\Delta$slit_y" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{ESA}$)/slit_y$_{ESA}$", "N" info_hist = [xlabel, ylabel, bins, notnan_rel_diff_pslity_stats] if notnan_rel_diff_pslity_stats[1] is np.nan: msg = "Unable to create plot of relative slit-y difference." print(msg) log_msgs.append(msg) else: plt_name = infile_name.replace( basenameinfile_name, slitlet_name + "_" + det + "_rel_slitY_diffs.pdf") auxfunc.plt_two_2Dimgandhist(rel_diff_pslity_img, notnan_rel_diff_pslity, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # MSA-x title = main_title + r"Relative MSA-x Difference = $\Delta$MSA_x" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{ESA}$)/MSA_x$_{ESA}$", "N" info_hist = [xlabel, ylabel, bins, notnan_reldiffpmsax_stats] if notnan_reldiffpmsax_stats[1] is np.nan: msg = "Unable to create plot of relative MSA-x difference." 
print(msg) log_msgs.append(msg) else: plt_name = infile_name.replace( basenameinfile_name, slitlet_name + "_" + det + "_rel_MSAx_diffs.pdf") auxfunc.plt_two_2Dimgandhist(reldiffpmsax_img, notnan_reldiffpmsax, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # MSA-y title = main_title + r"Relative MSA-y Difference = $\Delta$MSA_y" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{ESA}$)/MSA_y$_{ESA}$", "N" info_hist = [xlabel, ylabel, bins, notnan_reldiffpmsay_stats] if notnan_reldiffpmsay_stats[1] is np.nan: msg = "Unable to create plot of relative MSA-y difference." print(msg) log_msgs.append(msg) else: plt_name = infile_name.replace( basenameinfile_name, slitlet_name + "_" + det + "_rel_MSAy_diffs.pdf") auxfunc.plt_two_2Dimgandhist(reldiffpmsay_img, notnan_reldiffpmsay, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) if not skipv2v3test: # V2 title = main_title + r"Relative V2 Difference = $\Delta$V2" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{ESA}$)/V2$_{ESA}$", "N" hist_data = notnan_reldiffpv2 info_hist = [xlabel, ylabel, bins, notnan_reldiffpv2_stats] if notnan_reldiffpv2_stats[1] is np.nan: msg = "Unable to create plot of relative V2 difference." 
print(msg) log_msgs.append(msg) else: plt_name = infile_name.replace( basenameinfile_name, slitlet_name + "_" + det + "_rel_V2_diffs.pdf") auxfunc.plt_two_2Dimgandhist(reldiffpv2_img, hist_data, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # V3 title = main_title + r"Relative V3 Difference = $\Delta$V3" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{ESA}$)/V3$_{ESA}$", "N" hist_data = notnan_reldiffpv3 info_hist = [xlabel, ylabel, bins, notnan_reldiffpv3_stats] if notnan_reldiffpv3_stats[1] is np.nan: msg = "Unable to create plot of relative V3 difference." print(msg) log_msgs.append(msg) else: plt_name = infile_name.replace( basenameinfile_name, slitlet_name + "_" + det + "_rel_V3_diffs.pdf") auxfunc.plt_two_2Dimgandhist(reldiffpv3_img, hist_data, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) else: msg = "NO plots were made because show_figs and save_figs were both set to False. \n" print(msg) log_msgs.append(msg) # remove the copy of the MSA shutter configuration file subprocess.run(["rm", msametfl]) # If all tests passed then pytest will be marked as PASSED, else it will be FAILED FINAL_TEST_RESULT = "FAILED" for sl, testlist in total_test_result.items(): for tdict in testlist: for t, tr in tdict.items(): if tr == "FAILED": FINAL_TEST_RESULT = "FAILED" msg = "\n * The test of " + t + " for slitlet " + sl + " FAILED." print(msg) log_msgs.append(msg) else: FINAL_TEST_RESULT = "PASSED" msg = "\n * The test of " + t + " for slitlet " + sl + " PASSED." print(msg) log_msgs.append(msg) if FINAL_TEST_RESULT == "PASSED": msg = "\n *** Final result for assign_wcs test will be reported as PASSED *** \n" print(msg) log_msgs.append(msg) else: msg = "\n *** Final result for assign_wcs test will be reported as FAILED *** \n" print(msg) log_msgs.append(msg) return FINAL_TEST_RESULT, log_msgs
def do_correction_fixedslit(data, pathloss, inverse=False, source_type=None, correction_pars=None):
    """Path loss correction for NIRSpec fixed-slit modes

    Data is modified in-place. Each slit in ``data.slits`` is corrected
    exactly once per call.

    Parameters
    ----------
    data : jwst.datamodel.DataModel
        The NIRSpec fixed-slit data to be corrected.

    pathloss : jwst.datamodel.DataModel
        The pathloss reference data.

    inverse : boolean
        Invert the math operation used to apply the pathloss correction
        to the science data: when True, ``slit.data`` is multiplied by the
        correction instead of divided by it.

    source_type : str or None
        Force processing using the specified source type.

    correction_pars : jwst.datamodels.MultiSlitModel or None
        The precomputed pathloss to apply instead of recalculation.
        Its slits are matched to ``data.slits`` by position.

    Returns
    -------
    corrections : jwst.datamodel.MultiSlitModel
        The pathloss corrections applied, one entry per input slit in
        input order (including slits for which no correction was available).
    """
    exp_type = data.meta.exposure.type

    # Loop over all slits contained in the input
    corrections = datamodels.MultiSlitModel()
    for slit_number, slit in enumerate(data.slits):
        log.info(f'Working on slit {slit.name}')

        # Use the precomputed correction at the same slit index when one was
        # supplied; otherwise compute it from the pathloss reference data.
        if correction_pars:
            correction = correction_pars.slits[slit_number]
        else:
            correction = _corrections_for_fixedslit(slit, pathloss, exp_type, source_type)

        # Appended before the emptiness check, so the returned model always
        # has one entry per input slit.
        corrections.slits.append(correction)

        if not correction:
            log.warning(f'No correction provided for slit {slit_number}. Skipping')
            continue

        # Apply the correction ONCE. A second, unconditional application
        # would double-correct the forward case and cancel the inverse case.
        if not inverse:
            slit.data /= correction.data
        else:
            slit.data *= correction.data

        # NOTE(review): the error and variance arrays are divided by the
        # correction regardless of ``inverse`` -- confirm this is intended.
        correction_squared = correction.data**2
        slit.err /= correction.data
        slit.var_poisson /= correction_squared
        slit.var_rnoise /= correction_squared
        # var_flat is optional: skip it when absent or empty.
        if slit.var_flat is not None and np.size(slit.var_flat) > 0:
            slit.var_flat /= correction_squared

        # Record the correction arrays that were used on the slit itself.
        slit.pathloss_point = correction.pathloss_point
        slit.pathloss_uniform = correction.pathloss_uniform

    # Set step status to complete
    data.meta.cal_step.pathloss = 'COMPLETE'

    return corrections
def compare_wcs(infile_name, truth_file=None, esa_files_path=None, show_figs=True, save_figs=False, threshold_diff=1.0e-7, raw_data_root_file=None, output_directory=None, debug=False): """ This function does the WCS comparison from the world coordinates calculated using the pipeline data model with the truth files or the ESA intermediary files (to create new truth files). Args: infile_name: str, name of the output fits file from the assign_wcs step (with full path) truth_file: str, full path to the 'truth' (or benchmark) file to compare to esa_files_path: str or None, path to file all the ESA intermediary files show_figs: boolean, whether to show plots or not save_figs: boolean, save the plots or not threshold_diff: float, threshold difference between pipeline output and truth file raw_data_root_file: None or string, name of the raw file that produced the _uncal.fits file for caldetector1 output_directory: None or string, path to the output_directory where to save the plots debug: boolean, if true a series of print statements will show on-screen Returns: - plots, if told to save and/or show them. 
- median_diff: Boolean, True if smaller or equal to threshold - log_msgs: list, all print statements captured in this variable """ log_msgs = [] if truth_file is not None: # get the model from the "truth" (or comparison) file truth_hdul = fits.open(truth_file) print("Information from the 'truth' (or comparison) file ") print(truth_hdul.info()) truth_hdul.close() truth_img = datamodels.ImageModel(truth_file) # get the open slits in the truth file open_slits_truth = truth_img.meta.wcs.get_transform( 'gwa', 'slit_frame').slits # open the datamodel truth_model = datamodels.MultiSlitModel(truth_file) # determine to what are we comparing to esa_files_path = None print("Comparing to ST 'truth' file.") else: print("Comparing to ESA data") # get grating and filter info from the rate file header if isinstance(infile_name, str): # get the datamodel from the assign_wcs output file img = datamodels.ImageModel(infile_name) msg = 'infile_name=' + infile_name print(msg) log_msgs.append(msg) basenameinfile_name = os.path.basename(infile_name) else: img = infile_name basenameinfile_name = "" det = img.meta.instrument.detector lamp = img.meta.instrument.lamp_state grat = img.meta.instrument.grating filt = img.meta.instrument.filter msg = "from assign_wcs file --> Detector: " + det + " Grating: " + grat + " Filter: " + \ filt + " Lamp: " + lamp print(msg) log_msgs.append(msg) print("GWA_XTILT: {}".format(img.meta.instrument.gwa_xtilt)) print("GWA_YTILT: {}".format(img.meta.instrument.gwa_ytilt)) print("GWA_TTILT: {}".format(img.meta.instrument.gwa_tilt)) # list to determine if pytest is passed or not total_test_result = OrderedDict() if esa_files_path is not None: # mapping the ESA slit names to pipeline names map_slit_names = { 'SLIT_A_1600': 'S1600A1', 'SLIT_A_200_1': 'S200A1', 'SLIT_A_200_2': 'S200A2', 'SLIT_A_400': 'S400A1', 'SLIT_B_200': 'S200B1', } # To get the open and projected on the detector slits of the pipeline processed file open_slits = img.meta.wcs.get_transform('gwa', 
'slit_frame').slits for opslit in open_slits: pipeslit = opslit.name msg = "\nWorking with slit: " + pipeslit print(msg) log_msgs.append(msg) compare_to_esa_data = False if esa_files_path is not None: compare_to_esa_data = True # Get the ESA trace if raw_data_root_file is None: _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file( infile_name) specifics = [pipeslit] # check if ESA data is not in the regular directory tree, these files are exceptions NIDs = ["30055", "30055", "30205", "30133", "30133"] special_cutout_files = [ "NRSSMOS-MOD-G1H-02-5344031756_1_491_SE_2015-12-10T03h25m56.fits", "NRSSMOS-MOD-G1H-02-5344031756_1_492_SE_2015-12-10T03h25m56.fits", "NRSSMOS-MOD-G2M-01-5344191938_1_491_SE_2015-12-10T19h29m26.fits", "NRSSMOS-MOD-G3H-02-5344120942_1_491_SE_2015-12-10T12h18m25.fits", "NRSSMOS-MOD-G3H-02-5344120942_1_492_SE_2015-12-10T12h18m25.fits" ] if raw_data_root_file in special_cutout_files: nid = NIDs[special_cutout_files.index(raw_data_root_file)] msg = "Using NID = " + nid print(msg) log_msgs.append(msg) else: nid = None esafile, esafile_log_msgs = auxfunc.get_esafile(esa_files_path, raw_data_root_file, "FS", specifics, nid=nid) for msg in esafile_log_msgs: log_msgs.append(msg) # skip the test if the esafile was not found if esafile == "ESA file not found": msg1 = " * compare_wcs_fs.py is exiting because the corresponding ESA file was not found." msg2 = " -> The WCS test is now set to skip and no plots will be generated. 
" print(msg1) print(msg2) log_msgs.append(msg1) log_msgs.append(msg2) FINAL_TEST_RESULT = "skip" return FINAL_TEST_RESULT, log_msgs """ # comparison of filter, grating and x and y tilt gwa_xtil = fits.getval(infile_name, "gwa_xtil", 0) gwa_ytil = fits.getval(infile_name, "gwa_ytil", 0) esagrat = fits.getval(esafile, "GWA_POS", 0) esafilt = fits.getval(esafile, "FWA_POS", 0) esa_xtil = fits.getval(esafile, "gwa_xtil", 0) esa_ytil = fits.getval(esafile, "gwa_ytil", 0) print("pipeline: ") print("grating=", grat, " Filter=", filt, " gwa_xtil=", gwa_xtil, " gwa_ytil=", gwa_ytil) print("ESA:") print("grating=", esagrat, " Filter=", esafilt, " gwa_xtil=", esa_xtil, " gwa_ytil=", esa_ytil) """ # Open the trace in the esafile msg = "Using this ESA file: \n" + esafile print(msg) log_msgs.append(msg) with fits.open(esafile) as esahdulist: print("* ESA file contents ") esahdulist.info() esa_slit_id = map_slit_names[esahdulist[0].header['SLITID']] # first check is esa_slit == to pipe_slit? if pipeslit == esa_slit_id: msg = "\n -> Same slit found for pipeline and ESA data: " + pipeslit + "\n" print(msg) log_msgs.append(msg) else: msg = "\n -> Missmatch of slits for pipeline and ESA data: " + pipeslit, esa_slit_id + "\n" print(msg) log_msgs.append(msg) # Assign variables according to detector skipv2v3test = True if det == "NRS1": try: truth_flux = fits.getdata(esafile, "DATA1") truth_wave = fits.getdata(esafile, "LAMBDA1") truth_slity = fits.getdata(esafile, "SLITY1") truth_msax = fits.getdata(esafile, "MSAX1") truth_msay = fits.getdata(esafile, "MSAY1") pyw = wcs.WCS(esahdulist['LAMBDA1'].header) try: truth_v2 = fits.getdata(esafile, "V2V3X1") truth_v3 = fits.getdata(esafile, "V2V3Y1") skipv2v3test = False except KeyError: msg = "Skipping tests for V2 and V3 because ESA file does not contain corresponding " \ "extensions." print(msg) log_msgs.append(msg) except KeyError: msg = "This file does not contain data for detector NRS1. Skipping test for this slit." 
print(msg) log_msgs.append(msg) continue if det == "NRS2": try: truth_flux = fits.getdata(esafile, "DATA2") truth_wave = fits.getdata(esafile, "LAMBDA2") truth_slity = fits.getdata(esafile, "SLITY2") truth_msax = fits.getdata(esafile, "MSAX2") truth_msay = fits.getdata(esafile, "MSAY2") pyw = wcs.WCS(esahdulist['LAMBDA2'].header) try: truth_v2 = fits.getdata(esafile, "V2V3X2") truth_v3 = fits.getdata(esafile, "V2V3Y2") skipv2v3test = False except KeyError: msg = "Skipping tests for V2 and V3 because ESA file does not contain " \ "corresponding extensions." print(msg) log_msgs.append(msg) except KeyError: msg1 = "\n * compare_wcs_fs.py is exiting because there are no extensions that match " \ "detector NRS2 in the ESA file." msg2 = " -> The WCS test is now set to skip and no plots will be generated. \n" print(msg1) print(msg2) log_msgs.append(msg1) log_msgs.append(msg2) FINAL_TEST_RESULT = "skip" return FINAL_TEST_RESULT, log_msgs # Create x, y indices using the trace WCS from ESA pipey, pipex = np.mgrid[:truth_wave.shape[0], :truth_wave.shape[1]] esax, esay = pyw.all_pix2world(pipex, pipey, 0) if det == "NRS2": esax = 2049 - esax esay = 2049 - esay msg = "Flipped ESA data for detector NRS2 comparison with pipeline." print(msg) log_msgs.append(msg) # remove 1 to start from 0 truth_x, truth_y = esax - 1, esay - 1 # In case we are NOT comparing to ESA data if not compare_to_esa_data: # determine if the current open slit is also open in the truth file continue_with_wcs_test = False for truth_open_slit in open_slits_truth: if truth_open_slit.name == pipeslit: continue_with_wcs_test = True break if not continue_with_wcs_test: msg1 = "\n * Script compare_wcs_fs.py is exiting because open slit "+pipeslit+" is not open in " \ "truth file." msg2 = " -> The WCS test is now set to FAILED and no plots will be generated. 
\n" print(msg1) print(msg2) log_msgs.append(msg1) log_msgs.append(msg2) FINAL_TEST_RESULT = "FAILED" return FINAL_TEST_RESULT, log_msgs # get the WCS object for this particular truth slit slit_wcs = nirspec.nrs_wcs_set_input(truth_model, pipeslit) truth_x, truth_y = wcstools.grid_from_bounding_box( slit_wcs.bounding_box, step=(1, 1), center=True) truth_ra, truth_dec, truth_wave = slit_wcs( truth_x, truth_y) # wave is in microns truth_wave *= 10**-6 # (lam *= 10**-6 to convert to microns) # get the truths to compare to truth_det2slit = slit_wcs.get_transform('detector', 'slit_frame') truth_slitx, truth_slity, _ = truth_det2slit(truth_x, truth_y) truth_det2slit = slit_wcs.get_transform("detector", "msa_frame") truth_msax, truth_msay, _ = truth_det2slit(truth_x, truth_y) truth_det2slit = slit_wcs.get_transform("detector", "v2v3") truth_v2, truth_v3, _ = truth_det2slit(truth_x, truth_y) skipv2v3test = False # get the WCS object for this particular slit wcs_slit = nirspec.nrs_wcs_set_input(img, pipeslit) # if we want to print all available transforms, uncomment line below #print(wcs_slit) # In different observing modes the WCS may have different coordinate frames. To see available frames # uncomment line below. available_frames = wcs_slit.available_frames print("Avalable frames: ", available_frames) if debug: # To get specific pixel values use following syntax: det2slit = wcs_slit.get_transform('detector', 'slit_frame') slitx, slity, lam = det2slit(700, 1080) print("slitx: ", slitx) print("slity: ", slity) print("lambda: ", lam) # The number of inputs and outputs in each frame can vary. 
This can be checked with: print('Number on inputs: ', det2slit.n_inputs) print('Number on outputs: ', det2slit.n_outputs) # check if subarray is not FULL FRAME subarray = img.meta.subarray if "FULL" not in subarray: # In subarray coordinates # subtract xstart and ystart values in order to get subarray coords instead of full frame # wcs_slit.x(y)start are 1-based, turn them to 0-based for extraction xstart, ystart = img.meta.subarray.xstart, img.meta.subarray.ystart bounding_box = False else: bounding_box = True pra, pdec, pwave = wcs_slit( truth_x, truth_y, with_bounding_box=bounding_box) # => RETURNS: RA, DEC, LAMBDA pwave *= 10**-6 # (lam *= 10**-6 to convert to microns) # calculate and print statistics for slit-y and x relative differences tested_quantity = "Wavelength Difference" rel_diff_pwave_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_wave, pwave, tested_quantity) rel_diff_pwave_img, notnan_rel_diff_pwave, notnan_rel_diff_pwave_stats, print_stats = rel_diff_pwave_data for msg in print_stats: log_msgs.append(msg) test_result = auxfunc.does_median_pass_tes( notnan_rel_diff_pwave_stats[1], threshold_diff) msg = "\n * Result of the test for " + tested_quantity + ": " + test_result + "\n" print(msg) log_msgs.append(msg) total_test_result[pipeslit] = {tested_quantity: test_result} # get the transforms for pipeline slit-y det2slit = wcs_slit.get_transform('detector', 'slit_frame') slitx, slity, _ = det2slit(truth_x, truth_y, with_bounding_box=bounding_box) tested_quantity = "Slit-Y Difference" # calculate and print statistics for slit-y and x relative differences rel_diff_pslity_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_slity, slity, tested_quantity, absolute=False) rel_diff_pslity_img, notnan_rel_diff_pslity, notnan_rel_diff_pslity_stats, print_stats = rel_diff_pslity_data for msg in print_stats: log_msgs.append(msg) test_result = auxfunc.does_median_pass_tes( notnan_rel_diff_pslity_stats[1], 
threshold_diff) msg = "\n * Result of the test for " + tested_quantity + ": " + test_result + "\n" print(msg) log_msgs.append(msg) total_test_result[pipeslit] = {tested_quantity: test_result} # do the same for MSA x, y and V2, V3 detector2msa = wcs_slit.get_transform("detector", "msa_frame") pmsax, pmsay, _ = detector2msa(truth_x, truth_y, with_bounding_box=bounding_box) # => RETURNS: msaX, msaY, LAMBDA (lam *= 10**-6 to convert to microns) # MSA-x tested_quantity = "MSA_X Difference" reldiffpmsax_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_msax, pmsax, tested_quantity) reldiffpmsax_img, notnan_reldiffpmsax, notnan_reldiffpmsax_stats, print_stats = reldiffpmsax_data for msg in print_stats: log_msgs.append(msg) test_result = auxfunc.does_median_pass_tes( notnan_reldiffpmsax_stats[1], threshold_diff) msg = "\n * Result of the test for " + tested_quantity + ": " + test_result + "\n" print(msg) log_msgs.append(msg) total_test_result[pipeslit] = {tested_quantity: test_result} # MSA-y tested_quantity = "MSA_Y Difference" reldiffpmsay_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_msay, pmsay, tested_quantity) reldiffpmsay_img, notnan_reldiffpmsay, notnan_reldiffpmsay_stats, print_stats = reldiffpmsay_data for msg in print_stats: log_msgs.append(msg) test_result = auxfunc.does_median_pass_tes( notnan_reldiffpmsay_stats[1], threshold_diff) msg = "\n * Result of the test for " + tested_quantity + ": " + test_result + "\n" print(msg) log_msgs.append(msg) total_test_result[pipeslit] = {tested_quantity: test_result} # V2 and V3 if not skipv2v3test and 'v2v3' in available_frames: detector2v2v3 = wcs_slit.get_transform("detector", "v2v3") pv2, pv3, _ = detector2v2v3(truth_x, truth_y, with_bounding_box=bounding_box) # => RETURNS: v2, v3, LAMBDA (lam *= 10**-6 to convert to microns) tested_quantity = "V2 difference" reldiffpv2_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_v2, pv2, 
tested_quantity) # converting to degrees to compare with truth, pipeline is in arcsec if reldiffpv2_data[-2][0] > 0.0: print( "\nConverting pipeline results to degrees to compare with truth file" ) pv2 = pv2 / 3600. reldiffpv2_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_v2, pv2, tested_quantity) reldiffpv2_img, notnan_reldiffpv2, notnan_reldiffpv2_stats, print_stats = reldiffpv2_data for msg in print_stats: log_msgs.append(msg) test_result = auxfunc.does_median_pass_tes( notnan_reldiffpv2_stats[1], threshold_diff) msg = "\n * Result of the test for " + tested_quantity + ": " + test_result + "\n" print(msg) log_msgs.append(msg) total_test_result[pipeslit] = {tested_quantity: test_result} tested_quantity = "V3 difference" reldiffpv3_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_v3, pv3, tested_quantity) # converting to degrees to compare with truth, pipeline is in arcsec if reldiffpv3_data[-2][0] > 0.0: print( "\nConverting pipeline results to degrees to compare with truth file" ) pv3 = pv3 / 3600. 
reldiffpv3_data = auxfunc.get_reldiffarr_and_stats( threshold_diff, truth_slity, truth_v3, pv3, tested_quantity) reldiffpv3_img, notnan_reldiffpv3, notnan_reldiffpv3_stats, print_stats = reldiffpv3_data for msg in print_stats: log_msgs.append(msg) test_result = auxfunc.does_median_pass_tes( notnan_reldiffpv3_stats[1], threshold_diff) msg = "\n * Result of the test for " + tested_quantity + ": " + test_result + "\n" print(msg) log_msgs.append(msg) total_test_result[pipeslit] = {tested_quantity: test_result} # PLOTS if show_figs or save_figs: # set the common variables main_title = filt + " " + grat + " SLIT=" + pipeslit + "\n" bins = 15 # binning for the histograms, if None the function will automatically calculate number # lolim_x, uplim_x, lolim_y, uplim_y plt_origin = None # Wavelength title = main_title + r"Relative wavelength difference = $\Delta \lambda$" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{truth}) / \lambda_{truth}$", "N" info_hist = [xlabel, ylabel, bins, notnan_rel_diff_pwave_stats] if notnan_rel_diff_pwave_stats[1] is np.nan: msg = "Unable to create plot of relative wavelength difference." print(msg) log_msgs.append(msg) else: specific_plt_name = "_rel_wave_diffs.png" if isinstance(infile_name, str): plt_name = infile_name.replace( basenameinfile_name, pipeslit + "_" + det + specific_plt_name) else: if output_directory is not None: plt_name = os.path.join( output_directory, pipeslit + "_" + det + specific_plt_name) else: plt_name = os.path.join( os.getcwd(), pipeslit + "_" + det + specific_plt_name) print( "No output_directory was provided. 
Figures will be saved in current working directory:" ) print(plt_name + "\n") auxfunc.plt_two_2Dimgandhist(rel_diff_pwave_img, notnan_rel_diff_pwave, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # Slit-y title = main_title + r"Relative slit position = $\Delta$slit_y" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{truth}$)/slit_y$_{truth}$", "N" info_hist = [xlabel, ylabel, bins, notnan_rel_diff_pslity_stats] if notnan_rel_diff_pslity_stats[1] is np.nan: msg = "Unable to create plot of relative slit position." print(msg) log_msgs.append(msg) else: specific_plt_name = "_rel_slitY_diffs.png" if isinstance(infile_name, str): plt_name = infile_name.replace( basenameinfile_name, pipeslit + "_" + det + specific_plt_name) else: if output_directory is not None: plt_name = os.path.join( output_directory, pipeslit + "_" + det + specific_plt_name) else: plt_name = None save_figs = False print( "No output_directory was provided. Figures will NOT be saved." ) auxfunc.plt_two_2Dimgandhist(rel_diff_pslity_img, notnan_rel_diff_pslity, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # MSA-x title = main_title + r"Relative MSA-x Difference = $\Delta$MSA_x" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{truth}$)/MSA_x$_{truth}$", "N" info_hist = [xlabel, ylabel, bins, notnan_reldiffpmsax_stats] if notnan_reldiffpmsax_stats[1] is np.nan: msg = "Unable to create plot of relative MSA-x difference." 
print(msg) log_msgs.append(msg) else: specific_plt_name = "_rel_MSAx_diffs.png" if isinstance(infile_name, str): plt_name = infile_name.replace( basenameinfile_name, pipeslit + "_" + det + specific_plt_name) else: if output_directory is not None: plt_name = os.path.join( output_directory, pipeslit + "_" + det + specific_plt_name) else: plt_name = None save_figs = False print( "No output_directory was provided. Figures will NOT be saved." ) auxfunc.plt_two_2Dimgandhist(reldiffpmsax_img, notnan_reldiffpmsax, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # MSA-y title = main_title + r"Relative MSA-y Difference = $\Delta$MSA_y" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{truth}$)/MSA_y$_{truth}$", "N" info_hist = [xlabel, ylabel, bins, notnan_reldiffpmsay_stats] if notnan_reldiffpmsay_stats[1] is np.nan: msg = "Unable to create plot of relative MSA-y difference." print(msg) log_msgs.append(msg) else: specific_plt_name = "_rel_MSAy_diffs.png" if isinstance(infile_name, str): plt_name = infile_name.replace( basenameinfile_name, pipeslit + "_" + det + specific_plt_name) else: if output_directory is not None: plt_name = os.path.join( output_directory, pipeslit + "_" + det + specific_plt_name) else: plt_name = None save_figs = False print( "No output_directory was provided. Figures will NOT be saved." 
) auxfunc.plt_two_2Dimgandhist(reldiffpmsay_img, notnan_reldiffpmsay, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) if not skipv2v3test and 'v2v3' in available_frames: # V2 title = main_title + r"Relative V2 Difference = $\Delta$V2" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{truth}$)/V2$_{truth}$", "N" info_hist = [xlabel, ylabel, bins, notnan_reldiffpv2_stats] if notnan_reldiffpv2_stats[1] is np.nan: msg = "Unable to create plot of relative V2 difference." print(msg) log_msgs.append(msg) else: specific_plt_name = "_rel_V2_diffs.png" if isinstance(infile_name, str): plt_name = infile_name.replace( basenameinfile_name, pipeslit + "_" + det + specific_plt_name) else: if output_directory is not None: plt_name = os.path.join( output_directory, pipeslit + "_" + det + specific_plt_name) else: plt_name = None save_figs = False print( "No output_directory was provided. Figures will NOT be saved." ) auxfunc.plt_two_2Dimgandhist(reldiffpv2_img, notnan_reldiffpv2_stats, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) # V3 title = main_title + r"Relative V3 Difference = $\Delta$V3" + "\n" info_img = [title, "x (pixels)", "y (pixels)"] xlabel, ylabel = r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{truth}$)/V3$_{truth}$", "N" info_hist = [xlabel, ylabel, bins, notnan_reldiffpv3_stats] if notnan_reldiffpv3_stats[1] is np.nan: msg = "Unable to create plot of relative V3 difference." print(msg) log_msgs.append(msg) else: specific_plt_name = "_rel_V3_diffs.png" if isinstance(infile_name, str): plt_name = infile_name.replace( basenameinfile_name, pipeslit + "_" + det + specific_plt_name) else: if output_directory is not None: plt_name = os.path.join( output_directory, pipeslit + "_" + det + specific_plt_name) else: plt_name = None save_figs = False print( "No output_directory was provided. 
Figures will NOT be saved." ) auxfunc.plt_two_2Dimgandhist(reldiffpv3_img, notnan_reldiffpv3, info_img, info_hist, plt_name=plt_name, plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs) else: msg = "NO plots were made because show_figs and save_figs were both set to False. \n" print(msg) log_msgs.append(msg) # If all tests passed then pytest will be marked as PASSED, else it will be FAILED FINAL_TEST_RESULT = "FAILED" for sl, testdir in total_test_result.items(): for t, tr in testdir.items(): if tr == "FAILED": FINAL_TEST_RESULT = "FAILED" msg = "\n * The test of " + t + " for slit " + sl + " FAILED." print(msg) log_msgs.append(msg) else: FINAL_TEST_RESULT = "PASSED" msg = "\n * The test of " + t + " for slit " + sl + " PASSED." print(msg) log_msgs.append(msg) if FINAL_TEST_RESULT == "PASSED": msg = "\n *** Final result for assign_wcs test will be reported as PASSED *** \n" else: msg = "\n *** Final result for assign_wcs test will be reported as FAILED *** \n" print(msg) log_msgs.append(msg) return FINAL_TEST_RESULT, log_msgs
def do_correction(input_model, barshadow_model=None, inverse=False,
                  source_type=None, correction_pars=None):
    """Apply the bar shadow correction to every slit of a science model.

    The input model is copied, and each slit of the copy is corrected in
    place.  The correction for a slit is either taken from
    ``correction_pars`` (same slit index) or computed with
    ``_calc_correction``.

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Science data model to be corrected.

    barshadow_model : `~jwst.datamodels.BarshadowModel`
        Bar shadow reference data; only used when the correction has to be
        computed, i.e. when ``correction_pars`` is not given.

    inverse : boolean
        If True the science array is multiplied by the correction instead
        of being divided by it.  The ``err`` and variance arrays are divided
        by the correction in both cases.

    source_type : str or None
        Forwarded to ``_calc_correction`` to force processing with the
        specified source type.

    correction_pars : object with a ``slits`` sequence, or None
        Previously calculated corrections to use instead of recalculating
        them; entry ``i`` of ``correction_pars.slits`` is applied to slit
        ``i``.

    Returns
    -------
    output_model, corrections : `~jwst.datamodels.MultiSlitModel`, `~jwst.datamodels.MultiSlitModel`
        The corrected copy of the science model (each slit also carries the
        applied correction in its ``barshadow`` attribute), and a model
        holding the correction used for every slit.
    """
    log.debug('EXP_TYPE = %s' % input_model.meta.exposure.type)

    # The caller gets back a corrected copy of the input model
    output_model = input_model.copy()
    corrections = datamodels.MultiSlitModel()

    for index, slit in enumerate(output_model.slits):
        log.info('Working on slitlet %d' % slit.slitlet_id)

        # Re-use a supplied correction when there is one, otherwise compute it
        if correction_pars:
            slit_correction = correction_pars.slits[index]
        else:
            slit_correction = _calc_correction(slit, barshadow_model,
                                               source_type)
        corrections.slits.append(slit_correction)

        # Science array: multiply for the inverse operation, divide otherwise
        if inverse:
            slit.data *= slit_correction.data
        else:
            slit.data /= slit_correction.data

        # err is a standard deviation, so it scales with the correction;
        # the variance arrays scale with the square of the correction
        slit.err /= slit_correction.data
        slit.var_poisson /= slit_correction.data ** 2
        slit.var_rnoise /= slit_correction.data ** 2

        # var_flat may be absent or empty
        has_var_flat = (slit.var_flat is not None
                        and np.size(slit.var_flat) > 0)
        if has_var_flat:
            slit.var_flat /= slit_correction.data ** 2

        # Keep the applied correction with the slit
        slit.barshadow = slit_correction.data

    return output_model, corrections
def flattest(step_input_filename,
             dflatref_path=None,
             sfile_path=None,
             fflat_path=None,
             msa_shutter_conf=None,
             writefile=False,
             show_figs=True,
             save_figs=False,
             plot_name=None,
             threshold_diff=1.0e-14,
             debug=False):
    """
    Compare the pipeline flat field of a MOS exposure with an independently
    calculated flat, slit by slit.

    For every slit of the extract_2d product the flat correction is rebuilt
    pixel by pixel from the D-flat, S-flat and F-flat reference files and
    subtracted from the pipeline's interpolated flat. A slit passes when the
    absolute median of the differences (ignoring out-of-slit pixels and
    outliers beyond +-0.1) is at most threshold_diff.

    Args:
        step_input_filename: str, name of the flat_field step output fits file (with full path). The
                             "interpolatedflat" and "_extract_2d" file names are derived from it.
        dflatref_path: str, path (file name prefix) of the D-flat reference fits files
        sfile_path: str, path (file name prefix) of the S-flat reference fits files
        fflat_path: str, path (file name prefix) of the F-flat reference fits files
        msa_shutter_conf: str, full path and name of the MSA configuration fits file. Kept for
                          backward compatibility; the quadrant, row and column are read from the
                          slit itself, so the file is not opened.
        writefile: boolean, if True writes the fits files of the calculated flat and difference images
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots
        plot_name: string, desired plot name (currently not used, the name is derived from the input file)
        threshold_diff: float, threshold difference between pipeline output and calculated flat
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - FINAL_TEST_RESULT: Boolean, True if no slit failed; the string "skip" if the test could
                             not be run (no flat correspondence for the filter)
        - result_msg: str, one-line summary of the outcome
        - log_msgs: list, all print statements are captured in this variable
    """

    log_msgs = []

    # start the timer
    flattest_start_time = time.time()

    # get info from the rate file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    lamp = fits.getval(step_input_filename, "LAMP", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    msg = "rate_file --> Grating:" + grat + " Filter:" + filt + " Lamp:" + lamp
    print(msg)
    log_msgs.append(msg)

    # read in the on-the-fly flat image
    flatfile = step_input_filename.replace("flat_field.fits", "interpolatedflat.fits")

    # get the reference files

    # D-Flat
    dflat_ending = "f_01.03.fits"
    dfile = "_".join((dflatref_path, "nrs1", dflat_ending))
    if det == "NRS2":
        dfile = dfile.replace("nrs1", "nrs2")
    msg = "Using D-flat: " + dfile
    print(msg)
    log_msgs.append(msg)
    dfim = fits.getdata(dfile, "SCI")
    dfimdq = fits.getdata(dfile, "DQ")
    # need to flip/rotate the image into science orientation
    # keep in mind that 0,1,2 = z,y,x in Python, whereas =x,y,z in IDL
    dfim = np.transpose(dfim, (0, 2, 1))
    dfimdq = np.transpose(dfimdq)
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        dfim = dfim[..., ::-1, ::-1]
        dfimdq = dfimdq[..., ::-1, ::-1]
    naxis3 = fits.getval(dfile, "NAXIS3", "SCI")

    # get the wavelength values
    dfwave = np.array([])
    for i in range(naxis3):
        keyword = "_".join(("PFLAT", str(i + 1)))
        dfwave = np.append(dfwave, fits.getval(dfile, keyword, "SCI"))
    dfrqe = fits.getdata(dfile, 2)

    # S-flat
    flat_by_filter = {
        "F070LP": "FLAT4",
        "F100LP": "FLAT1",
        "F170LP": "FLAT2",
        "F290LP": "FLAT3",
        "CLEAR": "FLAT5",
    }
    flat = flat_by_filter.get(filt)
    if flat is None:
        msg = "No filter correspondence. Exiting the program."
        print(msg)
        log_msgs.append(msg)
        # This is the key argument for the assert pytest function
        result_msg = "Test skiped because there is no flat correspondance for the filter in the data: {}".format(
            filt)
        median_diff = "skip"
        return median_diff, result_msg, log_msgs

    sflat_ending = "f_01.01.fits"
    sfile = "_".join((sfile_path, grat, "OPAQUE", flat, "nrs1", sflat_ending))
    # pick the detector before logging, so the message names the file that is actually read
    if det == "NRS2":
        sfile = sfile.replace("nrs1", "nrs2")
    msg = "Using S-flat: " + sfile
    print(msg)
    log_msgs.append(msg)
    sfim = fits.getdata(sfile, "SCI")
    sfimdq = fits.getdata(sfile, "DQ")

    # need to flip/rotate image into science orientation
    sfim = np.transpose(sfim, (0, 2, 1))
    sfimdq = np.transpose(sfimdq, (0, 2, 1))
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        sfim = sfim[..., ::-1, ::-1]
        sfimdq = sfimdq[..., ::-1, ::-1]

    # get the wavelength values for sflat cube
    sfimwave = np.array([])
    naxis3 = fits.getval(sfile, "NAXIS3", "SCI")
    for i in range(naxis3):
        keyword = "FLAT_{:02d}".format(i + 1)
        try:
            sfimwave = np.append(sfimwave, fits.getval(sfile, keyword, "SCI"))
        except KeyError:
            # planes without a wavelength keyword are skipped
            pass
    sfv = fits.getdata(sfile, 5)

    # F-Flat
    fflat_ending = "_01.01.fits"
    ffile = fflat_path + "_" + filt + fflat_ending
    msg = "Using F-flat: " + ffile
    print(msg)
    log_msgs.append(msg)
    # the number of planes of quadrant 1 is used for all four quadrants
    naxis3 = fits.getval(ffile, "NAXIS3", "SCI_Q1")
    # quadrant number -> (slow cube, cube wavelengths, fast vector table)
    fflat_quadrants = {}
    for quadrant in (1, 2, 3, 4):
        sci_ext = "SCI_Q" + str(quadrant)
        ffs_cube = fits.getdata(ffile, sci_ext)
        ffs_wave = np.array([])
        for i in range(naxis3):
            keyword = "FLAT_{:02d}".format(i)
            ffs_wave = np.append(ffs_wave, fits.getval(ffile, keyword, sci_ext))
        ffv_table = fits.getdata(ffile, "Q" + str(quadrant))
        fflat_quadrants[quadrant] = (ffs_cube, ffs_wave, ffv_table)

    # quantities that are the same for every pixel of every slit
    dfrqe_wav = dfrqe.field("WAVELENGTH")
    dfrqe_rqe = dfrqe.field("RQE")
    sfv_wav = sfv.field("WAVELENGTH")
    sfv_dat = sfv.field("DATA")
    dfwave_min, dfwave_max = min(dfwave), max(dfwave)
    sfimwave_min, sfimwave_max = min(sfimwave), max(sfimwave)

    # now go through each pixel in the test data

    if writefile:
        # create the fits list to hold the image of the correction values
        outfile = fits.HDUList()
        outfile.append(fits.PrimaryHDU())

        # create the fits list to hold the image of the comparison values
        complfile = fits.HDUList()
        complfile.append(fits.PrimaryHDU())

    # list to determine if pytest is passed or not
    total_test_result = []

    # get the datamodel from the extract_2d output file
    extract2d_file = step_input_filename.replace("_flat_field.fits", "_extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_file)

    # get all the science extensions in the flatfile
    sci_ext_list = auxfunc.get_sci_extensions(flatfile)

    # loop over the 2D subwindows and read in the WCS values
    for slit in model.slits:
        slit_id = slit.name
        msg = "\nWorking with slit: " + slit_id
        print(msg)
        log_msgs.append(msg)
        # science extension of this slit in the pipeline calculated flat
        ext = sci_ext_list[slit_id]

        # read the pipeline-calculated flat image once per slit (not once per pixel)
        # there are four extensions in the flatfile: SCI, DQ, ERR, WAVELENGTH
        pipeflat = fits.getdata(flatfile, ext)

        # get the wavelength
        y, x = np.mgrid[:slit.data.shape[0], :slit.data.shape[1]]
        _, _, wave = slit.meta.wcs(x, y)  # wave is in microns

        # get the subwindow origin
        px0 = slit.xstart - 1 + model.meta.subarray.xstart
        py0 = slit.ystart - 1 + model.meta.subarray.ystart
        msg = " Subwindow origin: px0=" + repr(px0) + " py0=" + repr(py0)
        print(msg)
        log_msgs.append(msg)
        n_p = np.shape(wave)
        nw = n_p[0] * n_p[1]
        nw1, nw2 = n_p[1], n_p[0]  # remember that x=nw1 and y=nw2 are reversed in Python
        if debug:
            print("nw = ", nw)

        # 999.0 marks pixels for which nothing was calculated
        delf = np.zeros([nw2, nw1]) + 999.0
        flatcor = np.zeros([nw2, nw1]) + 999.0

        # slitlet location, needed for the F-Flat (changes suggested by Phil Hodge)
        quad = slit.quadrant
        row = slit.xcen
        col = slit.ycen
        slitlet_id = repr(row) + "_" + repr(col)
        msg = 'silt_id=' + repr(slit_id) + " quad=" + repr(quad) + " row=" + repr(row) + " col=" + repr(
            col) + " slitlet_id=" + repr(slitlet_id)
        print(msg)
        log_msgs.append(msg)

        # get the relevant F-flat reference data
        ffsall, ffsallwave, ffv = fflat_quadrants[quad]
        ffv_wav = ffv.field("WAVELENGTH")
        ffv_dat = ffv.field("DATA")

        # loop through the pixels
        msg = "Now looping through the pixels, this will take a while ... "
        print(msg)
        log_msgs.append(msg)
        wave_shape = np.shape(wave)
        for j in range(nw1):  # in x
            for k in range(nw2):  # in y
                if not np.isfinite(wave[k, j]):
                    # skip if wavelength is NaN
                    continue

                # get the pixel indeces
                jwav = wave[k, j]
                pind = [k + py0 - 1, j + px0 - 1]
                if debug:
                    print('j, k, jwav, px0, py0 : ', j, k, jwav, px0, py0)
                    print('pind = ', pind)

                # get the pixel bandwidth
                if j == 0:
                    delw = wave[k, j + 1] - wave[k, j]
                elif j == nw1 - 1:
                    # last column: only the previous neighbour exists
                    delw = wave[k, j] - wave[k, j - 1]
                else:
                    next_finite = np.isfinite(wave[k, j + 1])
                    prev_finite = np.isfinite(wave[k, j - 1])
                    if next_finite and prev_finite:
                        delw = 0.5 * (wave[k, j + 1] - wave[k, j - 1])
                    elif next_finite:
                        delw = wave[k, j + 1] - wave[k, j]
                    elif prev_finite:
                        delw = wave[k, j] - wave[k, j - 1]
                    # NOTE(review): when neither neighbour is finite, delw keeps the
                    # value of the previously processed pixel, as in the original code

                if debug:
                    if j + 1 < nw1:
                        print("wave[k, j+1], wave[k, j-1] : ", np.isfinite(wave[k, j + 1]), wave[k, j + 1],
                              wave[k, j - 1])
                    print("delw = ", delw)

                # integrate over dflat fast vector
                iw = np.where((dfrqe_wav >= wave[k, j] - delw / 2.0) & (dfrqe_wav <= wave[k, j] + delw / 2.0))
                if np.size(iw) == 0:
                    dff = 1.0
                else:
                    int_tab = auxfunc.idl_tabulate(dfrqe_wav[iw[0]], dfrqe_rqe[iw[0]])
                    first_dfrqe_wav, last_dfrqe_wav = dfrqe_wav[iw[0]][0], dfrqe_wav[iw[0]][-1]
                    dff = int_tab / (last_dfrqe_wav - first_dfrqe_wav)

                if debug:
                    print("dff = ", dff)

                # interpolate over dflat cube
                iloc = auxfunc.idl_valuelocate(dfwave, wave[k, j])[0]
                if dfwave[iloc] > wave[k, j]:
                    iloc -= 1
                ibr = [iloc]
                if iloc != len(dfwave) - 1:
                    ibr.append(iloc + 1)
                # get the values in the z-array at indeces ibr, and x=pind[1] and y=pind[0]
                zz = dfim[:, pind[0], pind[1]][ibr]
                # now determine the length of the array with only the finite numbers
                zzwherenonan = np.where(np.isfinite(zz))
                kk = np.size(zzwherenonan)
                dfs = 1.0
                if (wave[k, j] <= dfwave_max) and (wave[k, j] >= dfwave_min) and (kk == 2):
                    dfs = np.interp(wave[k, j], dfwave[ibr], zz[zzwherenonan])
                # check DQ flags
                if dfimdq[pind[0], pind[1]] != 0:
                    dfs = 1.0

                # integrate over S-flat fast vector
                iw = np.where((sfv_wav >= wave[k, j] - delw / 2.0) & (sfv_wav <= wave[k, j] + delw / 2.0))
                sff = 1.0
                if np.size(iw) > 2:
                    int_tab = auxfunc.idl_tabulate(sfv_wav[iw], sfv_dat[iw])
                    first_sfv_wav, last_sfv_wav = sfv_wav[iw[0]][0], sfv_wav[iw[0]][-1]
                    sff = int_tab / (last_sfv_wav - first_sfv_wav)

                # interpolate s-flat cube
                iloc = auxfunc.idl_valuelocate(sfimwave, wave[k, j])[0]
                ibr = [iloc]
                if iloc != len(sfimwave) - 1:
                    ibr.append(iloc + 1)
                # get the values in the z-array at indeces ibr, and x=pind[1] and y=pind[0]
                zz = sfim[:, pind[0], pind[1]][ibr]
                # now determine the length of the array with only the finite numbers
                zzwherenonan = np.where(np.isfinite(zz))
                kk = np.size(zzwherenonan)
                sfs = 1.0
                if (wave[k, j] <= sfimwave_max) and (wave[k, j] >= sfimwave_min) and (kk == 2):
                    sfs = np.interp(wave[k, j], sfimwave[ibr], zz[zzwherenonan])
                # check DQ flags
                kk = np.where(sfimdq[:, pind[0], pind[1]][ibr] == 0)
                if np.size(kk) != 2:
                    sfs = 1.0

                # integrate over f-flat fast vector
                # reference file wavelength range is from 0.6 to 5.206 microns, so need to force
                # solution to 1 for wavelengths outside that range
                fff = 1.0
                if (wave[k, j] - delw / 2.0 >= 0.6) and (wave[k, j] + delw / 2.0 <= 5.206):
                    iw = np.where((ffv_wav >= wave[k, j] - delw / 2.0) & (ffv_wav <= wave[k, j] + delw / 2.0))
                    if np.size(iw) > 1:
                        int_tab = auxfunc.idl_tabulate(ffv_wav[iw], ffv_dat[iw])
                        first_ffv_wav, last_ffv_wav = ffv_wav[iw[0]][0], ffv_wav[iw[0]][-1]
                        fff = int_tab / (last_ffv_wav - first_ffv_wav)

                # interpolate over f-flat cube
                ffs = np.interp(wave[k, j], ffsallwave, ffsall[:, col - 1, row - 1])

                flatcor[k, j] = dff * dfs * sff * sfs * fff * ffs

                # debugging hook: plot the reference data around one hard-coded pixel
                # (slit pixel x=9999, y=9999) and stop the test
                if (pind[1] - px0 + 1 == 9999) and (pind[0] - py0 + 1 == 9999):
                    if debug:
                        print("pind = ", pind)
                        print("wave[k, j] = ", wave[k, j])
                        print("dfs, dff = ", dfs, dff)
                        print("sfs, sff = ", sfs, sff)

                    msg = "Making the plot fot this slitlet..."
                    print(msg)
                    log_msgs.append(msg)
                    # make plot
                    font = {
                        'weight': 'normal',
                        'size': 16
                    }
                    matplotlib.rc('font', **font)
                    fig = plt.figure(1, figsize=(12, 10))
                    plt.subplots_adjust(hspace=.4)
                    ax = plt.subplot(111)
                    xmin = wave[k, j] - 0.01
                    xmax = wave[k, j] + 0.01
                    plt.xlim(xmin, xmax)
                    plt.plot(dfwave, dfim[:, pind[0], pind[1]], linewidth=7, marker='D', color='k',
                             label="dflat_im")
                    plt.plot(wave[k, j], dfs, linewidth=7, marker='D', color='r')
                    plt.plot(dfrqe_wav, dfrqe_rqe, linewidth=7, marker='D', c='k', label="dflat_vec")
                    plt.plot(wave[k, j], dff, linewidth=7, marker='D', color='r')
                    plt.plot(sfimwave, sfim[:, pind[0], pind[1]], linewidth=7, marker='D', color='k',
                             label="sflat_im")
                    plt.plot(wave[k, j], sfs, linewidth=7, marker='D', color='r')
                    plt.plot(sfv_wav, sfv_dat, linewidth=7, marker='D', color='k', label="sflat_vec")
                    plt.plot(wave[k, j], sff, linewidth=7, marker='D', color='r')
                    # add legend
                    box = ax.get_position()
                    ax.set_position([box.x0, box.y0, box.width * 1.0, box.height])
                    ax.legend(loc='upper right', bbox_to_anchor=(1, 1))
                    plt.minorticks_on()
                    plt.tick_params(axis='both', which='both', bottom=True, top=True, right=True,
                                    direction='in', labelbottom=True)
                    plt.show()
                    msg = "Exiting the program. Unable to calculate statistics. Test set to be SKIPPED."
                    print(msg)
                    log_msgs.append(msg)
                    plt.close()
                    model.close()
                    result_msg = "Unable to calculate statistics. Test set be SKIP."
                    median_diff = "skip"
                    return median_diff, result_msg, log_msgs

                if debug:
                    print("dfs = ", dfs)
                    print("sff = ", sff)
                    print("sfs = ", sfs)
                    print("ffs = ", ffs)

                try:
                    # Difference between pipeline and calculated values
                    delf[k, j] = pipeflat[k, j] - flatcor[k, j]

                    # Remove all pixels with values=1 (outside slit boundaries) for statistics
                    if pipeflat[k, j] == 1:
                        delf[k, j] = 999.0
                    # NOTE(review): kept from the original; pixels without a finite
                    # wavelength are skipped above, so this is not expected to trigger
                    if np.isnan(wave[k, j]):
                        flatcor[k, j] = 1.0  # no correction if no wavelength

                    if debug:
                        print("flatcor[k, j] = ", flatcor[k, j])
                        print("delf[k, j] = ", delf[k, j])
                except IndexError:
                    # the pipeline flat does not cover this pixel: leave the 999.0 marker
                    pass

        nanind = np.isnan(delf)  # get all the nan indexes
        notnan = ~nanind  # get all the not-nan indexes
        delf = delf[notnan]  # get rid of NaNs (this also flattens the array)
        delf_shape = np.shape(delf)
        test_result = "FAILED"
        if delf.size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN."
            msg2 = " Test will be set to FAILED and NO plots will be made."
            print(msg1)
            print(msg2)
            log_msgs.append(msg1)
            log_msgs.append(msg2)
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            # ignore outliers and the 999.0 markers
            delfg = delf[np.where((delf != 999.0) & (delf < 0.1) & (delf > -0.1))]
            if delfg.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values."
                msg2 = " Test will be set to FAILED and NO plots will be made."
                print(msg1)
                print(msg2)
                log_msgs.append(msg1)
                log_msgs.append(msg2)
            else:
                stats, _ = auxfunc.print_stats(delfg, "Flat Difference", float(threshold_diff), abs=True)
                _, delfg_median, delfg_std = stats

                # This is the key argument for the assert pytest function
                median_diff = abs(delfg_median) <= float(threshold_diff)
                test_result = "PASSED" if median_diff else "FAILED"

                if save_figs or show_figs:
                    # make histogram
                    msg = "Making histogram plot for this slitlet..."
                    print(msg)
                    log_msgs.append(msg)
                    # set the plot variables
                    main_title = filt + " " + grat + " SLIT=" + slit_id + "\n"
                    bins = None  # binning for the histograms, if None the function will select them automatically
                    # lolim_x, uplim_x, lolim_y, uplim_y
                    plt_origin = None

                    # Residuals img and histogram
                    title = main_title + "Residuals"
                    info_img = [title, "x (pixels)", "y (pixels)"]
                    xlabel, ylabel = "flat$_{pipe}$ - flat$_{calc}$", "N"
                    info_hist = [xlabel, ylabel, bins, stats]
                    # NaN never compares identical to np.nan for array values, so test by value
                    if np.isnan(delfg_median):
                        msg = "Unable to create plot of relative wavelength difference."
                        print(msg)
                        log_msgs.append(msg)
                    else:
                        file_path = step_input_filename.replace(os.path.basename(step_input_filename), "")
                        file_basename = os.path.basename(step_input_filename.replace(".fits", ""))
                        plt_name = "_".join((file_basename, "MOS_flattest_" + slitlet_id + "_histogram.pdf"))
                        plt_name = os.path.join(file_path, plt_name)
                        difference_img = (pipeflat - flatcor)
                        # ignore out of slitlet pixels (they still hold the 999.0 marker)
                        in_slit = np.logical_and(difference_img < 900.0, difference_img > -900.0)
                        difference_img[~in_slit] = np.nan  # Set values outside the slit to NaN
                        # set the range of values to be shown in the image, will affect color scale
                        vminmax = [-5 * delfg_std, 5 * delfg_std]
                        auxfunc.plt_two_2Dimgandhist(difference_img, delfg, info_img, info_hist,
                                                     plt_name=plt_name, vminmax=vminmax,
                                                     plt_origin=plt_origin, show_figs=show_figs,
                                                     save_figs=save_figs)
                else:
                    msg = "Not making plots because both show_figs and save_figs were set to False."
                    print(msg)
                    log_msgs.append(msg)

        msg = " *** Result of the test: " + test_result + "\n"
        print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

        # create fits file to hold the calculated flat for each slit
        if writefile:
            # this is the file to hold the image of the correction values
            outfile_ext = fits.ImageHDU(flatcor.reshape(wave_shape), name=slitlet_id)
            outfile.append(outfile_ext)

            # this is the file to hold the image of the comparison values
            complfile_ext = fits.ImageHDU(delf.reshape(delf_shape), name=slitlet_id)
            complfile.append(complfile_ext)

            # the file is not yet written, indicate that this slit was appended to list to be written
            msg = "Extension corresponding to slitlet " + slitlet_id + " appended to list to be written into calculated and comparison fits files."
            print(msg)
            log_msgs.append(msg)

    # all slits have been processed, the data model is no longer needed
    model.close()

    if writefile:
        outfile_name = step_input_filename.replace("flat_field.fits", det + "_flat_calc.fits")
        complfile_name = step_input_filename.replace("flat_field.fits", det + "_flat_comp.fits")

        # this is the file to hold the image of the calculated flat values
        outfile.writeto(outfile_name, overwrite=True)

        # this is the file to hold the image of pipeline-calculated difference values
        complfile.writeto(complfile_name, overwrite=True)

        msg = "\nFits file with calculated flat values of each slit saved as: "
        print(msg)
        print(outfile_name)
        log_msgs.append(msg)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline flat - calculated flat) saved as: "
        print(msg)
        print(complfile_name)
        log_msgs.append(msg)
        log_msgs.append(complfile_name)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    FINAL_TEST_RESULT = "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        msg = "\n *** Final result for flat_field test will be reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slitlets PASSED flat_field test."
    else:
        msg = "\n *** Final result for flat_field test will be reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slitlets FAILED flat_field test."

    # end the timer
    flattest_end_time = time.time() - flattest_start_time
    if flattest_end_time > 60.0:
        flattest_end_time = flattest_end_time / 60.0  # in minutes
        time_units = " minutes"
        if flattest_end_time > 60.0:
            flattest_end_time = flattest_end_time / 60.  # in hours
            time_units = " hours"
    else:
        time_units = " seconds"
    # build one string (a comma here would create a tuple)
    flattest_tot_time = "* Script flattest_mos.py took " + repr(flattest_end_time) + time_units + " to finish."
    print(flattest_tot_time)
    log_msgs.append(flattest_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
def flattest(step_input_filename, dflatref_path=None, sfile_path=None, fflat_path=None, writefile=True,
             show_figs=True, save_figs=False, plot_name=None, threshold_diff=1.0e-7, debug=False):
    """
    Compare the pipeline flat field against an independently calculated flat for Fixed Slit data.

    For every slit found in the data, the flat correction is recomputed pixel by pixel from the
    D-, S- and F-flat reference files and subtracted from the pipeline "interpolatedflat" product.
    The slit passes when the absolute median of the (outlier-clipped) differences is not larger
    than threshold_diff.

    Args:
        step_input_filename: str, name of the flat_field step output fits file (with full path). The
                             names of the interpolated flat and extract_2d files are derived from it.
        dflatref_path: str, path (file name prefix) of the D-flat reference fits files
        sfile_path: str, path (file name prefix) of the S-flat reference fits files, must contain "FS"
        fflat_path: str, path (file name prefix) of the F-flat reference fits files, must contain "FS"
        writefile: boolean, if True writes the fits files of the calculated flat and difference images
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, whether to save the plots or not
        plot_name: string, accepted for interface compatibility; it is not used by this function
        threshold_diff: float, threshold difference between pipeline output and calculated flat
        debug: boolean, if True a series of print statements will show on-screen

    Returns:
        A 3-element tuple:
        - FINAL_TEST_RESULT: True if at least one slit was tested and none FAILED, False otherwise,
                             or the string "skip" when the filter or reference paths are not usable.
        - result_msg: str, one-line summary of the outcome
        - log_msgs: list, all reported messages are captured in this variable
    """

    log_msgs = []

    def _report(text):
        # every user-facing message goes both on-screen and into the returned log
        print(text)
        log_msgs.append(text)

    # start the timer
    flattest_start_time = time.time()

    # get info from the file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    _report('step_input_filename=' + step_input_filename)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    _report("flat_field_file --> Grating:" + grat + " Filter:" + filt + " EXP_TYPE:" + exptype)

    # name of the on-the-fly flat image written by the pipeline
    flatfile = step_input_filename.replace("flat_field.fits", "interpolatedflat.fits")

    # get the reference files

    # D-Flat
    dflat_ending = "f_01.03.fits"
    dfile = "_".join((dflatref_path, "nrs1", dflat_ending))
    if det == "NRS2":
        dfile = dfile.replace("nrs1", "nrs2")
    _report("Using D-flat: " + dfile)
    dfim = fits.getdata(dfile, "SCI")
    dfimdq = fits.getdata(dfile, "DQ")
    # need to flip/rotate the image into science orientation
    dfim = np.transpose(dfim, (0, 2, 1))  # keep in mind that 0,1,2 = z,y,x in Python, whereas =x,y,z in IDL
    dfimdq = np.transpose(dfimdq)
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        dfim = dfim[..., ::-1, ::-1]
        dfimdq = dfimdq[..., ::-1, ::-1]
    naxis3 = fits.getval(dfile, "NAXIS3", 1)
    if debug:
        print('np.shape(dfim) =', np.shape(dfim))
        print('np.shape(dfimdq) =', np.shape(dfimdq))

    # get the wavelength values of the D-flat cube planes from the PFLAT_n keywords
    dfwave = np.array([])
    for plane in range(naxis3):
        keyword = "_".join(("PFLAT", str(plane + 1)))
        dfwave = np.append(dfwave, fits.getval(dfile, keyword, 1))
    dfrqe = fits.getdata(dfile, 2)
    # the fast-vector columns do not change from pixel to pixel, read them only once
    dfrqe_wav = dfrqe.field("WAVELENGTH")
    dfrqe_rqe = dfrqe.field("RQE")

    # S-flat
    mode = "FS"
    flat_by_filter = {"F070LP": "FLAT4", "F100LP": "FLAT1", "F170LP": "FLAT2",
                      "F290LP": "FLAT3", "CLEAR": "FLAT5"}
    flat = flat_by_filter.get(filt)
    if flat is None:
        _report("No filter correspondence. Exiting the program.")
        result_msg = "Test skiped because there is no flat correspondance for the filter in the data: {}".format(filt)
        median_diff = "skip"
        return median_diff, result_msg, log_msgs

    sflat_ending = "f_01.01.fits"
    if mode in sfile_path:
        sfile = "_".join((sfile_path, grat, "OPAQUE", flat, "nrs1", sflat_ending))
    else:
        _report("Wrong path in for mode S-flat. This script handles mode " + mode + "only.")
        # This is the key argument for the assert pytest function
        result_msg = "Wrong path in for mode S-flat. Test skiped because mode is not FS."
        median_diff = "skip"
        return median_diff, result_msg, log_msgs

    if debug:
        print("grat = ", grat)
        print("flat = ", flat)
        print("sfile used = ", sfile)

    if det == "NRS2":
        sfile = sfile.replace("nrs1", "nrs2")
    _report("Using S-flat: " + sfile)
    sfim = fits.getdata(sfile, "SCI")
    sfimdq = fits.getdata(sfile, "DQ")

    # need to flip/rotate image into science orientation
    sfim = np.transpose(sfim)
    sfimdq = np.transpose(sfimdq)
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        sfim = sfim[..., ::-1, ::-1]
        sfimdq = sfimdq[..., ::-1, ::-1]
    if debug:
        print("np.shape(sfim) = ", np.shape(sfim))
        print("np.shape(sfimdq) = ", np.shape(sfimdq))

    # show the structure of the S-flat file; the context manager closes the file handle, and
    # info() already prints its summary so it must not be wrapped in print()
    with fits.open(sfile) as sf:
        sf.info()

    # S-flat fast vectors, one table extension per slit. Each one is read independently so that a
    # missing extension only affects its own slit instead of leaving names undefined.
    sfv_extnames = {"S200A1": "SLIT_A_200_1", "S200A2": "SLIT_A_200_2",
                    "S400A1": "SLIT_A_400", "S1600A1": "SLIT_A_1600"}
    if det == "NRS2":
        sfv_extnames["S200B1"] = "SLIT_B_200"
    sfv_by_slit = {}
    for sfv_slit, sfv_extname in sfv_extnames.items():
        try:
            sfv_by_slit[sfv_slit] = fits.getdata(sfile, sfv_extname)
        except KeyError:
            _report(" * S-Flat-Field file does not have extension " + sfv_extname + " for slit " + sfv_slit)

    # F-Flat
    fflat_ending = "01.01.fits"
    if mode in fflat_path:
        ffile = "_".join((fflat_path, filt, fflat_ending))
    else:
        msg = "Wrong path in for mode F-flat. This script handles mode " + mode + "only."
        _report(msg)
        # This is the key argument for the assert pytest function
        median_diff = "skip"
        return median_diff, msg, log_msgs

    _report("Using F-flat: " + ffile)
    ffv = fits.getdata(ffile, 1)
    ffv_wav = ffv.field("WAVELENGTH")
    ffv_dat = ffv.field("DATA")

    # now go through each pixel in the test data

    # get the datamodel from the extract_2d output file
    extract2d_wcs_file = step_input_filename.replace("_flat_field.fits", "_extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_wcs_file)

    if writefile:
        # create the fits list to hold the calculated flat values for each slit
        outfile = fits.HDUList()
        outfile.append(fits.PrimaryHDU())

        # create the fits list to hold the image of pipeline-calculated difference values
        complfile = fits.HDUList()
        complfile.append(fits.PrimaryHDU())

    # list to determine if pytest is passed or not
    total_test_result = []

    # loop over the slits
    sltname_list = ["S200A1", "S200A2", "S400A1", "S1600A1"]
    _report("Now looping through the slits. This may take a while... ")
    if det == "NRS2":
        sltname_list.append("S200B1")

    # but check if data is BOTS
    is_bots = exptype == "NRS_BRIGHTOBJ"
    if is_bots:
        sltname_list = ["S1600A1"]

    # get all the science extensions
    sci_ext_list = auxfunc.get_sci_extensions(flatfile)

    # do the loop over the slits
    for slit_id in sltname_list:
        if is_bots:
            slit = model
        else:
            slit = None
            for slit_in_MultiSlitModel in model.slits:
                if slit_in_MultiSlitModel.name == slit_id:
                    slit = slit_in_MultiSlitModel
                    break
            if slit is None:
                # this slit is not in the data, go to the next one
                continue

        _report("\nWorking with slit: " + slit_id)
        try:
            ext = sci_ext_list[slit_id]  # science extension in the pipeline calculated flat
        except KeyError:
            # the extension does not exist in the file, look for the next slit name
            continue

        # select the appropriate S-flat fast vector
        sfv = sfv_by_slit.get(slit_id)
        if sfv is None:
            # without the fast vector the flat cannot be recalculated; report it as a failure
            # rather than crashing on an undefined name
            _report(" * No S-flat fast vector available for slit " + slit_id + ". Test will be set to FAILED.")
            test_result = "FAILED"
            _report(" *** Result of the test: " + test_result + "\n")
            total_test_result.append(test_result)
            continue
        sfv_wav = sfv.field("WAVELENGTH")
        sfv_dat = sfv.field("DATA")

        # get the wavelength
        x, y = wcstools.grid_from_bounding_box(slit.meta.wcs.bounding_box, step=(1, 1), center=True)
        ra, dec, wave = slit.meta.wcs(x, y)  # wave is in microns

        # get the subwindow origin
        px0 = slit.xstart - 1 + model.meta.subarray.xstart
        py0 = slit.ystart - 1 + model.meta.subarray.ystart
        _report(" Subwindow origin: px0=" + repr(px0) + " py0=" + repr(py0))
        n_p = np.shape(wave)
        nw = n_p[0] * n_p[1]
        nw1, nw2 = n_p[1], n_p[0]  # remember that x=nw1 and y=nw2 are reversed in Python
        if debug:
            print(" nw1, nw2, nw = ", nw1, nw2, nw)

        # 999.0 is the sentinel for "not calculated"
        delf = np.zeros([nw2, nw1]) + 999.0
        flatcor = np.zeros([nw2, nw1]) + 999.0

        # read the pipeline-calculated flat image
        # there are four extensions in the flatfile: SCI, DQ, ERR, WAVELENGTH
        pipeflat = fits.getdata(flatfile, ext)

        # make sure the two arrays are the same shape
        if np.shape(flatcor) != np.shape(pipeflat):
            _report('WARNING -> Something went wrong, arrays are not the same shape:')
            _report('np.shape(flatcor) = ' + repr(np.shape(flatcor)) + ' np.shape(pipeflat) = ' + repr(
                np.shape(pipeflat)))
            _report('Mathematical operations will fail. Exiting the loop here and setting test as FAILED.')
            # the result has to be set BEFORE it is used to build the message
            test_result = "FAILED"
            _report(" *** Result of the test: " + test_result + "\n")
            total_test_result.append(test_result)
            continue

        # loop through the wavelengths
        _report(" Looping through the wavelengths... ")
        for j in range(nw1):  # in x
            for k in range(nw2):  # in y
                if not np.isfinite(wave[k, j]):  # skip if wavelength is NaN
                    continue

                # get the full-frame pixel indices for D- and S-flat image components
                pind = [k + py0 - 1, j + px0 - 1]

                # get the pixel bandwidth
                if j == 0:
                    delw = wave[k, j + 1] - wave[k, j]
                elif j == nw1 - 1:
                    # last column: compare against the number of columns (nw1), not the total
                    # number of pixels (nw)
                    delw = wave[k, j] - wave[k, j - 1]
                else:
                    right_ok = np.isfinite(wave[k, j + 1])
                    left_ok = np.isfinite(wave[k, j - 1])
                    if right_ok and left_ok:
                        delw = 0.5 * (wave[k, j + 1] - wave[k, j - 1])
                    elif right_ok:
                        delw = wave[k, j + 1] - wave[k, j]
                    elif left_ok:
                        delw = wave[k, j] - wave[k, j - 1]
                    # NOTE(review): when neither neighbour is finite delw keeps the value of the
                    # previously processed pixel, as in the original code -- confirm this is wanted

                # integrate over D-flat fast vector
                iw = np.where((dfrqe_wav >= wave[k, j] - delw / 2.) & (dfrqe_wav <= wave[k, j] + delw / 2.))
                int_tab = auxfunc.idl_tabulate(dfrqe_wav[iw], dfrqe_rqe[iw])
                first_dfrqe_wav, last_dfrqe_wav = dfrqe_wav[iw[0]][0], dfrqe_wav[iw[0]][-1]
                dff = int_tab / (last_dfrqe_wav - first_dfrqe_wav)

                if debug:
                    print("np.shape(dfrqe_wav) : ", np.shape(dfrqe_wav))
                    print("np.shape(dfrqe_rqe) : ", np.shape(dfrqe_rqe))
                    print("dfimdq[pind[0],[pind[1]] : ", dfimdq[pind[0], pind[1]])
                    print("np.shape(iw) =", np.shape(iw))
                    print("np.shape(dfrqe_wav) = ", np.shape(dfrqe_wav[iw]))
                    print("np.shape(dfrqe_rqe) = ", np.shape(dfrqe_rqe[iw]))
                    print("int_tab=", int_tab)
                    print("np.shape(dfim) = ", np.shape(dfim))
                    print("dff = ", dff)

                # interpolate over D-flat cube
                iloc = auxfunc.idl_valuelocate(dfwave, wave[k, j])[0]
                if dfwave[iloc] > wave[k, j]:
                    iloc -= 1
                ibr = [iloc]
                if iloc != len(dfwave) - 1:
                    ibr.append(iloc + 1)
                # get the values in the z-array at indices ibr, and x=pind[1] and y=pind[0]
                zz = dfim[:, pind[0], pind[1]][ibr]
                # now determine the length of the array with only the finite numbers
                zzwherenonan = np.where(np.isfinite(zz))
                kk = np.size(zzwherenonan)
                dfs = 1.0
                if (wave[k, j] <= max(dfwave)) and (wave[k, j] >= min(dfwave)) and (kk == 2):
                    dfs = np.interp(wave[k, j], dfwave[ibr], zz[zzwherenonan])
                # check DQ flags
                if dfimdq[pind[0]][pind[1]] != 0:
                    dfs = 1.0

                if debug:
                    print("wave[k, j] = ", wave[k, j])
                    print("iloc = ", iloc)
                    print("ibr = ", ibr)
                    if kk == len(ibr):
                        # np.interp needs both coordinate arrays to have the same length
                        print("np.interp(wave[k, j], dfwave[ibr], zz[zzwherenonan]) = ",
                              np.interp(wave[k, j], dfwave[ibr], zz[zzwherenonan]))
                    print("dfs = ", dfs)

                # integrate over S-flat fast vector
                iw = np.where((sfv_wav >= wave[k, j] - delw / 2.0) & (sfv_wav <= wave[k, j] + delw / 2.0))
                sff = 1.0
                if np.size(iw) > 2:
                    int_tab = auxfunc.idl_tabulate(sfv_wav[iw], sfv_dat[iw])
                    first_sfv_wav, last_sfv_wav = sfv_wav[iw[0]][0], sfv_wav[iw[0]][-1]
                    sff = int_tab / (last_sfv_wav - first_sfv_wav)
                # get s-flat pixel-dependent correction
                sfs = 1.0
                if sfimdq[pind[0], pind[1]] == 0:
                    sfs = sfim[pind[0], pind[1]]

                if debug:
                    print("np.shape(iw) =", np.shape(iw))
                    print("np.shape(sfv_wav) = ", np.shape(sfv_wav))
                    print("np.shape(sfv_dat) = ", np.shape(sfv_dat))
                    print("int_tab = ", int_tab)
                    print("sff = ", sff)
                    print("sfs = ", sfs)

                # integrate over F-flat fast vector
                # reference file blue cutoff is 1 micron, so need to force solution for shorter wavs
                fff = 1.0
                if wave[k, j] - delw / 2.0 >= 1.0:
                    iw = np.where((ffv_wav >= wave[k, j] - delw / 2.0) & (ffv_wav <= wave[k, j] + delw / 2.0))
                    if np.size(iw) > 1:
                        int_tab = auxfunc.idl_tabulate(ffv_wav[iw], ffv_dat[iw])
                        first_ffv_wav, last_ffv_wav = ffv_wav[iw[0]][0], ffv_wav[iw[0]][-1]
                        fff = int_tab / (last_ffv_wav - first_ffv_wav)

                flatcor[k, j] = dff * dfs * sff * sfs * fff

                if debug:
                    print("np.shape(iw) =", np.shape(iw))
                    print("np.shape(ffv_wav) = ", np.shape(ffv_wav))
                    print("np.shape(ffv_dat) = ", np.shape(ffv_dat))
                    print("fff = ", fff)
                    print("flatcor[k, j] = ", flatcor[k, j])
                    print("dff, dfs, sff, sfs, fff:", dff, dfs, sff, sfs, fff)

                try:
                    # Difference between pipeline and calculated values
                    delf[k, j] = pipeflat[k, j] - flatcor[k, j]

                    if debug:
                        print("delf[k, j] = ", delf[k, j])

                    # Remove all pixels with values=1 (outside slit boundaries) for statistics
                    if pipeflat[k, j] == 1:
                        delf[k, j] = 999.0
                    if np.isnan(wave[k, j]):
                        flatcor[k, j] = 1.0  # no correction if no wavelength

                    if debug:
                        print("flatcor[k, j] = ", flatcor[k, j])
                        print("delf[k, j] = ", delf[k, j])
                except IndexError:
                    # only out-of-range pixels are deliberately ignored; anything else propagates
                    pass

        if debug:
            no_999 = delf[np.where(delf != 999.0)]
            print("np.shape(no_999) = ", np.shape(no_999))
            alldelf = delf.flatten()
            print("median of the whole array: ", np.median(alldelf))
            print("median, stdev in delf: ", np.median(no_999), np.std(no_999))
            neg_vals = no_999[np.where(no_999 < 0.0)]
            print("neg_vals = ", np.shape(neg_vals))
            print("np.shape(delf) = ", np.shape(delf))

        # Statistics are done on a NaN-free 1D copy; delf itself stays a 2D image so that it can
        # be written as an image extension further down.
        delf_finite = delf[~np.isnan(delf)]
        delfg = np.array([])  # defined in every branch because the plotting code reads it
        delfg_mean, delfg_median, delfg_std = np.nan, np.nan, np.nan
        stats = [delfg_mean, delfg_median, delfg_std]
        test_result = "FAILED"
        if delf_finite.size == 0:
            _report(" * Unable to calculate statistics because difference array has all values as NaN. "
                    "Test will be set to FAILED.")
        else:
            _report("Calculating statistics... ")
            # ignore the sentinel and the outliers
            delfg = delf_finite[np.where((delf_finite != 999.0) & (delf_finite < 0.1) & (delf_finite > -0.1))]
            if debug:
                print("np.shape(delfg) = ", np.shape(delfg))
            if delfg.size == 0:
                _report(" * Unable to calculate statistics because difference array has all outlier values. "
                        "Test will be set to FAILED.")
            else:
                stats, stats_print_strings = auxfunc.print_stats(delfg, "Flat Difference",
                                                                 float(threshold_diff), abs=True)
                delfg_mean, delfg_median, delfg_std = stats
                log_msgs.extend(stats_print_strings)

                # This is the key argument for the assert pytest function
                if abs(delfg_median) <= float(threshold_diff):
                    test_result = "PASSED"

        _report(" *** Result of the test: " + test_result + "\n")
        total_test_result.append(test_result)

        # make histogram
        if show_figs or save_figs:
            # set plot variables
            main_title = filt + " " + grat + " SLIT=" + slit_id + "\n"
            bins = None  # binning for the histograms, if None the function will select them automatically
            plt_origin = None

            # Residuals img and histogram
            title = main_title + "Residuals"
            info_img = [title, "x (pixels)", "y (pixels)"]
            xlabel, ylabel = "flat$_{pipe}$ - flat$_{calc}$", "N"
            info_hist = [xlabel, ylabel, bins, stats]
            # NaN has to be detected with np.isnan (an identity test against np.nan never matches
            # a computed value); nothing can be plotted without usable differences either
            if delfg.size == 0 or np.isnan(delfg_median):
                _report("Unable to create plot of relative wavelength difference.")
            else:
                file_path = step_input_filename.replace(os.path.basename(step_input_filename), "")
                file_basename = os.path.basename(step_input_filename.replace(".fits", ""))
                plt_name = "_".join((file_basename, "FS_flattest_" + slit_id + "_histogram.pdf"))
                plt_name = os.path.join(file_path, plt_name)
                difference_img = (pipeflat - flatcor)
                # ignore points out of the slit (they still carry the 999.0 sentinel)
                in_slit = np.logical_and(difference_img < 900.0, difference_img > -900.0)
                difference_img[~in_slit] = np.nan  # Set values outside the slit to NaN
                # range of values to be shown in the image, will affect color scale
                vminmax = [-5 * delfg_std, 5 * delfg_std]
                auxfunc.plt_two_2Dimgandhist(difference_img, delfg, info_img, info_hist, plt_name=plt_name,
                                             vminmax=vminmax, plt_origin=plt_origin, show_figs=show_figs,
                                             save_figs=save_figs)
        else:
            _report("Not making plots because both show_figs and save_figs were set to False.")

        # create fits file to hold the calculated flat for each slit
        if writefile:
            _report("Saving the fits files with the calculated flat for each slit...")

            # this is the extension to hold the image of calculated flat values
            outfile.append(fits.ImageHDU(flatcor, name=slit_id))

            # this is the extension to hold the image of pipeline-calculated difference values
            complfile.append(fits.ImageHDU(delf, name=slit_id))

            # the file is not yet written, indicate that this slit was appended to list to be written
            _report("Extension corresponding to slit " + slit_id +
                    " appended to list to be written into calculated and comparison fits files.")

    # every slit has been processed, the data model is not needed any more
    model.close()

    if writefile:
        outfile_name = step_input_filename.replace("flat_field.fits", det + "_flat_calc.fits")
        complfile_name = step_input_filename.replace("flat_field.fits", det + "_flat_comp.fits")

        # write the calculated flat values of each slit
        outfile.writeto(outfile_name, overwrite=True)

        # write the image of pipeline-calculated difference values
        complfile.writeto(complfile_name, overwrite=True)

        _report("\nFits file with calculated flat values of each slit saved as: ")
        _report(outfile_name)

        _report("Fits file with comparison (pipeline flat - calculated flat) saved as: ")
        _report(complfile_name)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED.
    # No tested slit at all also counts as not passed.
    FINAL_TEST_RESULT = bool(total_test_result) and "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        _report("\n *** Final result for flat_field test will be reported as PASSED *** \n")
        result_msg = "All slits PASSED flat_field test."
    else:
        _report("\n *** Final result for flat_field test will be reported as FAILED *** \n")
        result_msg = "One or more slits FAILED flat_field test."

    # end the timer; the message is built by concatenation so that it is a string, not a tuple
    flattest_end_time = time.time() - flattest_start_time
    if flattest_end_time > 60.0:
        flattest_end_time = flattest_end_time / 60.0  # in minutes
        flattest_tot_time = "* Script flattest_fs.py took " + repr(flattest_end_time) + " minutes to finish."
        if flattest_end_time > 60.0:
            flattest_end_time = flattest_end_time / 60.  # in hours
            flattest_tot_time = "* Script flattest_fs.py took " + repr(flattest_end_time) + " hours to finish."
    else:
        flattest_tot_time = "* Script flattest_fs.py took " + repr(flattest_end_time) + " seconds to finish."
    _report(flattest_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
def contam_corr(input_model, waverange, photom, max_cores):
    """
    Apply the WFSS contamination correction to every slit of a grism exposure.

    A full-frame simulated grism image is built for all non-zero spectral
    orders. For each source slit the source's own simulated spectrum is
    removed from that image, and the remainder (the contamination estimate)
    is cut out at the slit position and subtracted from the slit data.

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Input data model containing 2D spectral cutouts

    waverange : `~jwst.datamodels.WavelengthrangeModel`
        Wavelength range reference file model

    photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel`
        Photom (flux cal) reference file model

    max_cores : string
        Fraction of the available cores to use: 'none' requests a single
        process, 'quarter', 'half' and 'all' request that fraction of the
        cores reported by `multiprocessing.cpu_count`. Any other value
        falls back to a single process.

    Returns
    -------
    output_model : `~jwst.datamodels.MultiSlitModel`
        A copy of the input_model that has been corrected

    simul_model : `~jwst.datamodels.ImageModel`
        Full-frame simulated image of the grism exposure

    contam_model : `~jwst.datamodels.MultiSlitModel`
        Contamination estimate images for each source slit
    """
    # Translate the max_cores keyword into a number of processes
    if max_cores == 'none':
        ncpus = 1
    else:
        num_cores = multiprocessing.cpu_count()
        if max_cores == 'all':
            ncpus = num_cores
        elif max_cores == 'half':
            ncpus = max(num_cores // 2, 1)
        elif max_cores == 'quarter':
            ncpus = max(num_cores // 4, 1)
        else:
            ncpus = 1
        log.debug(f"Found {num_cores} cores; using {ncpus}")

    # The corrections are applied to a copy, the input is left alone
    output_model = input_model.copy()

    # Segmentation map and direct image that belong to this grism exposure
    seg_model = datamodels.open(input_model.meta.segmentation_map)
    image_names = [input_model.meta.direct_image]
    log.debug(f"Direct image names={image_names}")

    # The grism WCS is taken from the first slit of the input
    grism_wcs = input_model.slits[0].meta.wcs

    # Orders listed in the Wavelengthrange ref file, without order 0 entries
    all_orders = np.asarray(waverange.order)
    spec_orders = all_orders[all_orders != 0]
    log.debug(f"Spectral orders defined = {spec_orders}")

    # FILTER and PUPIL wheel positions. NIRCam WFSS has its filters in the
    # FILTER wheel and the grisms in the PUPIL wheel; for NIRISS WFSS it is
    # the other way around, so there the filter name is the PUPIL value.
    filter_kwd = input_model.meta.instrument.filter
    pupil_kwd = input_model.meta.instrument.pupil
    filter_name = pupil_kwd if input_model.meta.instrument.name == 'NIRISS' else filter_kwd

    # Per-order wavelength limits and sensitivity (inverse flux cal) curves
    wmin, wmax = {}, {}
    sens_waves, sens_response = {}, {}
    for order in spec_orders:
        order_range = waverange.get_wfss_wavelength_range(filter_name, [order])[order]
        wmin[order] = order_range[0]
        wmax[order] = order_range[1]
        waves, response = get_photom_data(photom, filter_kwd, pupil_kwd, order)
        sens_waves[order] = waves
        sens_response[order] = response
    log.debug(f"wmin={wmin}, wmax={wmax}")

    # Simulate the full grism image, one order at a time, and add them up
    obs = Observation(image_names, seg_model, grism_wcs, filter_name,
                      boundaries=[0, 2047, 0, 2047], max_cpu=ncpus)
    simul_all = None
    for order in spec_orders:
        log.info(f"Creating full simulated grism image for order {order}")
        obs.disperse_all(order, wmin[order], wmax[order],
                         sens_waves[order], sens_response[order])
        if simul_all is None:
            simul_all = obs.simulated_image
        else:
            simul_all += obs.simulated_image

    # Keep the full-frame simulation as an image model
    simul_model = datamodels.ImageModel(data=simul_all)
    simul_model.update(input_model, only="PRIMARY")

    # Estimate and subtract the contamination, slit by slit
    log.info("Creating contamination image for each individual source")
    contam_model = datamodels.MultiSlitModel()
    contam_model.update(input_model)
    contam_slits = []
    for slit in output_model.slits:
        source_id = slit.source_id
        slit_order = slit.meta.wcsinfo.spectral_order

        # Simulate only this source, starting from an empty image
        chunk = np.where(obs.IDs == source_id)[0][0]  # find chunk for this source
        obs.simulated_image = np.zeros(obs.dims)
        obs.disperse_chunk(chunk, slit_order, wmin[slit_order], wmax[slit_order],
                           sens_waves[slit_order], sens_response[slit_order])

        # Everything in the full simulation that is not this source
        contam = simul_all - obs.simulated_image

        # Cut out the region covered by the slit (xstart/ystart are 1-based)
        col_lo = slit.xstart - 1
        col_hi = col_lo + slit.xsize
        row_lo = slit.ystart - 1
        row_hi = row_lo + slit.ysize
        cutout = contam[row_lo:row_hi, col_lo:col_hi]

        contam_slit = datamodels.SlitModel(data=cutout)
        copy_slit_info(slit, contam_slit)
        contam_slits.append(contam_slit)

        # Remove the contamination estimate from the source slit
        slit.data -= cutout

    # Store the contamination estimates of all slits
    contam_model.slits.extend(contam_slits)

    # Set the step status to COMPLETE
    output_model.meta.cal_step.wfss_contam = 'COMPLETE'

    return output_model, simul_model, contam_model
def pathtest(step_input_filename, reffile, comparison_filename, writefile=True,
             show_figs=False, save_figs=True, threshold_diff=1.0e-7, debug=False):
    """
    Compare the pipeline pathloss correction of each Fixed Slit against an
    independently calculated uniform-source correction.

    For every slit the reference-file correction is interpolated onto the
    slit wavelengths, subtracted from the pipeline correction, and the median
    of the residuals is compared against ``threshold_diff``.

    Args:
        step_input_filename: str, full path name of sourcetype output fits file
        reffile: str, path to the pathloss FS reference fits file
        comparison_filename: str, path to pipeline-generated pathloss fits file
        writefile: boolean, if True writes the fits files of calculated
                   correction and difference images
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, whether to save the plots or not
        threshold_diff: float, threshold difference between pipeline output
                        and comparison file
        debug: boolean, if true print statements will show on-screen

    Returns:
        A 3-item tuple:
        - result: the string 'skip' if the comparison or the input file does
          not exist; otherwise a boolean that is True only when at least one
          slit was tested and none of the tested slits FAILED.
        - result_msg: str, one-line summary of the outcome.
        - log_msgs: list, all print statements are captured in this variable.
    """
    log_msgs = []

    # start the timer
    pathtest_start_time = time.time()

    # get info from the input file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    msg = "pathloss file: Grating:" + grat + " Filter:" + filt + " EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)

    # this script only tests the uniform (extended) source correction
    is_point_source = False
    # the header value does not change while looping, so read it only once
    is_bots = exptype == "NRS_BRIGHTOBJ"

    # get the datamodel from the extract_2d output file (it carries the WCS)
    extract2d_wcs_file = step_input_filename.replace("_srctype.fits", "_extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_wcs_file)

    if writefile:
        # create the fits list to hold the calculated pathloss values for each slit
        outfile = fits.HDUList()
        outfile.append(fits.PrimaryHDU())

        # create fits list to hold pipeline-calculated difference values
        compfile = fits.HDUList()
        compfile.append(fits.PrimaryHDU())

    # list to determine if pytest is passed or not
    total_test_result = []

    print('Checking files exist & obtaining datamodels, takes a few mins...')
    if os.path.isfile(comparison_filename):
        if debug:
            print('Comparison file does exist.')
    else:
        result_msg = 'Comparison file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        _close_models(model)
        return 'skip', result_msg, log_msgs

    # get the comparison data model
    pathloss_pipe = datamodels.open(comparison_filename)

    # For the moment, the pipeline is using the wrong reference file for slit 400A1, so read file that
    # re-processed with the right reference file and open corresponding data model.
    # The file is optional: keep None when it is missing so the S400A1 slit can be skipped cleanly.
    pathloss_pipe_400a1 = None
    pathloss_400a1 = step_input_filename.replace("srctype.fits", "pathloss_400A1.fits")
    if os.path.isfile(pathloss_400a1):
        pathloss_pipe_400a1 = datamodels.open(pathloss_400a1)
    if debug:
        print('got comparison datamodel!')

    if os.path.isfile(step_input_filename):
        if debug:
            print('Input file does exist.')
    else:
        result_msg = 'Input file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        _close_models(model, pathloss_pipe, pathloss_pipe_400a1)
        return 'skip', result_msg, log_msgs

    # get the input data model
    pl = datamodels.open(step_input_filename)
    if debug:
        print('got input datamodel!')

    msg = "Now looping through the slits. This may take a while... "
    print(msg)
    log_msgs.append(msg)

    sltname_list = ["S200A1", "S200A2", "S400A1", "S1600A1"]
    if det == "NRS2":
        sltname_list.append("S200B1")
    # but check if data is BOTS
    if is_bots:
        sltname_list = ["S1600A1"]

    # get all the science extensions
    ps_uni_ext_list = get_ps_uni_extensions(reffile, is_point_source)

    slit_val = 0
    for slit, pipe_slit in zip(pl.slits, pathloss_pipe.slits):
        slit_val = slit_val + 1
        slit_id = slit.name

        # swap the srctype slit for the object that carries the WCS
        continue_pathloss_test = False
        if is_bots:
            slit = model
            continue_pathloss_test = True
        else:
            for slit_in_MultiSlitModel in model.slits:
                if slit_in_MultiSlitModel.name == slit_id:
                    slit = slit_in_MultiSlitModel
                    continue_pathloss_test = True
                    break
        if not continue_pathloss_test:
            continue

        try:
            if is_point_source:
                ext = ps_uni_ext_list[0][slit_id]
                print("Retrieved point source extension")
            else:
                ext = ps_uni_ext_list[1][slit_id]
                print("Retrieved extended source extension for {}".format(slit_val))
        except KeyError:
            # gets index associated with slit if issue above
            ext = sltname_list.index(slit_id)
            print("Unable to retrieve extension.")

        wcs_obj = slit.meta.wcs

        # get the wavelength
        x, y = wcstools.grid_from_bounding_box(wcs_obj.bounding_box, step=(1, 1), center=True)
        _, _, wave = wcs_obj(x, y)
        wave_sci = wave * 10**(-6)  # microns --> meters

        # adjustments for S400A1
        if slit_id == "S400A1":
            if is_point_source:
                ext = 1
            else:
                ext = 3
                print("Got uniform source extension frome extra reference file")
            reffile2use = "jwst-nirspec-a400.plrf.fits"
        else:
            reffile2use = reffile

        print("Using reference file {}".format(reffile2use))
        plcor_ref_ext = fits.getdata(reffile2use, ext)
        # the context manager guarantees the reference file is closed for every slit
        with fits.open(reffile2use) as hdul:
            plcor_ref = hdul[1].data
            w = wcs.WCS(hdul[1].header)
            w1, y1, x1 = np.mgrid[:plcor_ref.shape[0], :plcor_ref.shape[1], :plcor_ref.shape[2]]
            _, _, wave_ref = w.all_pix2world(x1, y1, w1, 0)

        previous_sci = slit.data

        # reset per slit so values from a previous slit can never leak into this one
        comp_sci = None
        pipe_correction = None
        if slit_id == 'S400A1':
            if pathloss_pipe_400a1 is not None:
                for pipe_slit_400a1 in pathloss_pipe_400a1.slits:
                    if pipe_slit_400a1.name == "S400A1":
                        comp_sci = pipe_slit_400a1.data
                        pipe_correction = pipe_slit_400a1.pathloss
                        break
        else:
            comp_sci = pipe_slit.data
            pipe_correction = pipe_slit.pathloss

        if pipe_correction is None:
            msg = "No re-processed pathloss product found for slit S400A1. Skipping testing this slit."
            print(msg)
            log_msgs.append(msg)
            continue
        if len(pipe_correction) == 0:
            print("Pipeline pathloss correction in datamodel is empty. Skipping testing this slit.")
            continue

        # set up generals for all the plots
        font = {'weight': 'normal', 'size': 7}
        matplotlib.rc('font', **font)

        corr_vals = np.interp(wave_sci, wave_ref[:, 0, 0], plcor_ref_ext)
        corrected_array = previous_sci / corr_vals
        # residuals (pipe correction - my correction)
        corr_residuals = pipe_correction - corr_vals

        # Plots:
        step_input_filepath = step_input_filename.replace(".fits", "")
        fig = plt.figure()
        _plot_image_panel(321, corr_vals, 'Calculated Correction')
        _plot_image_panel(322, pipe_correction, 'Pathloss Correction Comparison')
        _plot_image_panel(323, corr_residuals, 'Correction residuals')
        _plot_image_panel(324, previous_sci, 'Normalized pipeline science data before pathloss')
        _plot_image_panel(325, comp_sci, 'Normalized pipeline science data after pathloss')
        _plot_image_panel(326, corrected_array, 'My science data after pathloss')
        fig.suptitle("FS UNI Pathloss Correction Testing for {}".format(slit_id))
        # add space between the subplots
        fig.subplots_adjust(wspace=0.9)

        # Show and/or save figures
        plt_name = step_input_filepath + "_Pathloss_test_slit_" + str(slit_id) + "_FS_extended.png"
        _save_and_show_figure(plt_name, save_figs, show_figs, debug, log_msgs)
        plt.clf()

        # create fits file to hold the calculated pathloss for each slit
        if writefile:
            msg = "Saving the fits files with the calculated pathloss for each slit..."
            print(msg)
            log_msgs.append(msg)

            # this is the file to hold the image of the calculated correction values
            outfile.append(fits.ImageHDU(corr_vals, name=slit_id))

            # this is the file to hold the image of pipeline-calculated difference values
            compfile.append(fits.ImageHDU(corr_residuals, name=slit_id))

        # Histogram
        notnan = ~np.isnan(corr_residuals)  # get all the not-nan indexes
        ax = plt.subplot(212)
        plt.hist(corr_residuals[notnan], bins=100, range=(-0.00000013, 0.00000013))
        plt.title('Residuals Histogram')
        plt.xlabel("Correction Value")
        plt.ylabel("Number of Occurences")
        arr_mean = np.mean(corr_residuals[notnan])
        arr_median = np.median(corr_residuals[notnan])
        arr_stddev = np.std(corr_residuals[notnan])
        plt.axvline(arr_mean, label="mean = %0.3e" % (arr_mean), color="g")
        plt.axvline(arr_median, label="median = %0.3e" % (arr_median), linestyle="-.", color="b")
        str_arr_stddev = "stddev = {:0.3e}".format(arr_stddev)
        ax.text(0.73, 0.67, str_arr_stddev, transform=ax.transAxes, fontsize=7)
        plt.legend()
        plt.minorticks_on()

        # Show and/or save figures
        plt_name = step_input_filename.replace(".fits", "") + "_Pathlosstest_slitlet_" + slit_id + ".png"
        _save_and_show_figure(plt_name, save_figs, show_figs, debug, log_msgs)
        plt.close()

        if corr_residuals[notnan].size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN. " \
                   "Test will be set to FAILED."
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            # ignore outliers (NaN compares False everywhere, so NaNs are dropped as well)
            keep = (corr_residuals != 999.0) & (corr_residuals < 0.1) & (corr_residuals > -0.1)
            corr_residuals = corr_residuals[keep]
            if corr_residuals.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values. Test " \
                       "will be set to FAILED."
                if debug:
                    print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
            else:
                stats, stats_print_strings = auxfunc.print_stats(corr_residuals, "Difference",
                                                                 float(threshold_diff), absolute=True)
                _, corr_residuals_median, _ = stats
                log_msgs.extend(stats_print_strings)

                # This is the key argument for the assert pytest function
                if abs(corr_residuals_median) <= float(threshold_diff):
                    test_result = "PASSED"
                else:
                    test_result = "FAILED"

        msg = " *** Result of the test: " + test_result + "\n"
        if debug:
            print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

    if writefile:
        outfile_name = step_input_filename.replace("srctype", det + "_calcuated_pathloss")
        compfile_name = step_input_filename.replace("srctype", det + "_comparison_pathloss")

        # write the calculated pathloss values for each slit
        outfile.writeto(outfile_name, overwrite=True)

        # write the pipeline-calculated difference values
        compfile.writeto(compfile_name, overwrite=True)

        msg = "\nFits file with calculated pathloss values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline pathloss - calculated pathloss) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(compfile_name)
        log_msgs.append(compfile_name)

    # every array needed from the data models has been used by now
    _close_models(model, pathloss_pipe, pathloss_pipe_400a1, pl)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED.
    # No tested slit at all also counts as a failure.
    FINAL_TEST_RESULT = bool(total_test_result) and "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        msg = "\n *** Final result for path_loss test will be reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED path_loss test."
    else:
        msg = "\n *** Final result for path_loss test will be reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED path_loss test."

    # end the timer
    pathloss_tot_time = _elapsed_time_message(time.time() - pathtest_start_time)
    print(pathloss_tot_time)
    log_msgs.append(pathloss_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs


def _plot_image_panel(position, image, title):
    """Draw one normalized image panel of the 3x2 pathloss comparison figure."""
    plt.subplot(position)
    norm = ImageNormalize(image)
    plt.imshow(image, norm=norm, aspect=10.0, origin='lower', cmap='viridis')
    plt.xlabel('dispersion in pixels')
    plt.ylabel('y in pixels')
    plt.title(title)
    plt.colorbar()


def _save_and_show_figure(plt_name, save_figs, show_figs, debug, log_msgs):
    """Save and/or show the current figure, logging why nothing was saved."""
    # save before showing, so the figure is written before the interactive
    # window gets a chance to dispose of it
    if save_figs:
        plt.savefig(plt_name)
        print('Figure saved as: ', plt_name)
    if show_figs:
        plt.show()
    if not save_figs:
        if show_figs:
            msg = "Not saving plots because save_figs was set to False."
        else:
            msg = "Not making plots because both show_figs and save_figs were set to False."
        if debug:
            print(msg)
        log_msgs.append(msg)


def _close_models(*models):
    """Close every given data model, skipping the ones that were never opened."""
    for dm in models:
        if dm is not None:
            dm.close()


def _elapsed_time_message(elapsed_seconds):
    """Return the run-time message as one string, in seconds, minutes or hours."""
    elapsed, units = elapsed_seconds, "seconds"
    if elapsed > 60.0:
        elapsed, units = elapsed / 60.0, "minutes"
        if elapsed > 60.0:
            elapsed, units = elapsed / 60.0, "hours"
    return "* Script FS_UNI.py took " + repr(elapsed) + " " + units + " to finish."