예제 #1
0
def slit_data_a():
    """Build a small synthetic "science" MultiSlitModel for tests.

    The model holds one slit whose data plane is a constant 10.0, whose DQ
    plane is all zeros, and whose wavelength plane repeats the linear ramp
    1.3 -> 4.8 (endpoints included) on every row, shifted by 0.1 per row
    and rounded to 4 decimal places.

    Returns
    -------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Model containing the single synthetic slit.
    """
    data_shape = (5, 9)
    n_cols = data_shape[1]
    row_step = 0.1

    science = np.full(data_shape, 10., dtype=np.float32)
    quality = np.zeros(data_shape, dtype=np.uint32)

    # Base wavelength ramp shared by every row.
    ramp = np.linspace(1.3, 4.8, num=n_cols, endpoint=True,
                       retstep=False, dtype=np.float32)

    # Each row is the ramp shifted by row_step relative to the previous row.
    wl_plane = np.zeros(data_shape, dtype=np.float32)
    for row_index, row in enumerate(wl_plane):
        row[:] = ramp + row_index * row_step
    wl_plane = np.around(wl_plane, 4)

    model = datamodels.MultiSlitModel()
    model.slits.append(datamodels.SlitModel(init=None,
                                            data=science,
                                            dq=quality,
                                            wavelength=wl_plane))
    return model
예제 #2
0
def test_nrs_msaspec():
    """Check source-type assignment for an NRS_MSASPEC exposure.

    Slits with stellarity 0.9 and -1 are expected to come back as POINT,
    and the slit with stellarity 0.5 as EXTENDED.
    """
    model = datamodels.MultiSlitModel()
    model.meta.exposure.type = "NRS_MSASPEC"

    for source_id, stellarity in ((1, 0.9), (2, -1), (3, 0.5)):
        model.slits.append({'source_id': source_id, 'stellarity': stellarity})

    result = srctype.set_source_type(model)

    expected_types = ['POINT', 'POINT', 'EXTENDED']
    for index, source_type in enumerate(expected_types):
        assert (result.slits[index].source_type == source_type)
예제 #3
0
    def __init__(self, input_DM):
        """
        Short Summary
        -------------
        Store the input and open it as a data model.

        CubeModel and ImageModel inputs are kept as opened. Any other
        DataModel is closed and reopened as a MultiSlitModel. If opening
        fails, the error is logged and ``self.input`` is set to None.

        Parameters
        ----------
        input_DM: data model object or file name
            Input passed to ``datamodels.open``; also stored unchanged
            in ``self.input_file``.

        """
        self.input_file = input_DM
        try:
            model = datamodels.open(input_DM)
            keep_as_opened = isinstance(
                model, (datamodels.CubeModel, datamodels.ImageModel))
            # Other DataModel flavours are reopened as MultiSlit.
            if not keep_as_opened and isinstance(model, datamodels.DataModel):
                model.close()
                model = datamodels.MultiSlitModel(input_DM)
            self.input = model
        except Exception as errmess:
            # Best-effort open: record why it failed instead of discarding
            # the exception, then fall back to no input.
            log.error('Error opening %s: %s', input_DM, errmess)
            self.input = None
예제 #4
0
def create_background_from_multislit(input_model):
    """Create a 1D master background spectrum from a set of
    calibrated background MOS slitlets in the input
    MultiSlitModel.

    Slitlets are selected when their ``source_name`` contains the
    substring "background"; slits with no source name are ignored. The
    selected slits are resampled, 1D-extracted and combined.

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        The input data model containing all slit instances.

    Returns
    -------
    master_bkg: `~jwst.datamodels.CombinedSpecModel` or None
        The 1D master background spectrum created from the inputs, or
        None when no background slitlets are found.
    """
    from ..resample import resample_spec_step
    from ..extract_1d import extract_1d_step
    from ..combine_1d.combine1d import combine_1d_spectra

    log.info('Creating MOS master background from background slitlets')

    # Select dedicated background slitlets. source_name may be unset (None),
    # in which case the substring test would raise TypeError.
    slits = [slit for slit in input_model.slits
             if slit.source_name and "background" in slit.source_name]
    for slit in slits:
        log.info('Using background slitlet %s', slit.source_name)

    if not slits:
        log.warning(
            'No background slitlets found; skipping master bkg correction')
        return None

    # Copy the selected slitlets, with the input metadata, to a temporary model
    bkg_model = datamodels.MultiSlitModel()
    bkg_model.update(input_model)
    bkg_model.slits.extend(slits)

    # Apply resample_spec and extract_1d to all background slitlets
    log.info('Applying resampling and 1D extraction to background slits')
    resamp = resample_spec_step.ResampleSpecStep.call(bkg_model)
    x1d = extract_1d_step.Extract1dStep.call(resamp)

    # Call combine_1d to combine the 1D background spectra
    log.info('Combining 1D background spectra into master background')
    master_bkg = combine_1d_spectra(x1d, exptime_key='exposure_time')

    # Drop references to the intermediate products before returning
    del bkg_model
    del resamp
    del x1d

    return master_bkg
예제 #5
0
def test_nrs_fixedslit():
    """Check source-type assignment for an NRS_FIXEDSLIT exposure.

    With the target source type set to EXTENDED and the primary fixed slit
    set to S200A1, slit S200A1 is expected to be EXTENDED and the other
    slits POINT.
    """
    model = datamodels.MultiSlitModel()
    model.meta.exposure.type = "NRS_FIXEDSLIT"
    model.meta.instrument.fixed_slit = "S200A1"
    model.meta.target.source_type = 'EXTENDED'

    expectations = [('S200A2', 'POINT'),
                    ('S200A1', 'EXTENDED'),
                    ('S1600A1', 'POINT')]
    for slit_name, _ in expectations:
        model.slits.append({'name': slit_name})

    result = srctype.set_source_type(model)

    for index, (_, source_type) in enumerate(expectations):
        assert (result.slits[index].source_type == source_type)
예제 #6
0
def compute_world_coordinates(fname, output=None):
    """
    Compute wavelengths and spatial coordinates of a NIRSPEC FS or MOS
    observation after running extract_2d and write them to a FITS file.

    One image extension is written per slit. Its data has four planes:
    wavelength, the two spatial coordinates returned by the slit WCS
    (labelled with the last frame of the WCS pipeline), and the y
    coordinate in ``slit_frame``.

    Parameters
    ----------
    fname : str
        The name of a file with extracted slits, i.e. the output
        of extract2d.
    output : str, optional
        The name of the output file, used exactly as given. If None, the
        part of the input file name before the first underscore is used,
        followed by ``_world_coordinates.fits``.

    Raises
    ------
    ValueError
        If EXP_TYPE is neither NRS_FIXEDSLIT nor NRS_MSASPEC.

    Examples
    --------
    >>> compute_world_coordinates('nrs1_fixed_assign_wcs_extract_2d.fits')

    """
    model = datamodels.MultiSlitModel(fname)
    try:
        exp_type = model.meta.exposure.type
        if exp_type.lower() not in ('nrs_fixedslit', 'nrs_msaspec'):
            raise ValueError("Expected a FS or MOS observation "
                             "(EXP_TYPE=NRS_FIXEDSLIT, NRS_MSASPEC), "
                             "got {0}".format(exp_type))
        hdulist = fits.HDUList()
        phdu = fits.PrimaryHDU()
        phdu.header['filename'] = model.meta.filename
        phdu.header['data'] = 'world coordinates'
        hdulist.append(phdu)
        output_frame = model.slits[0].meta.wcs.available_frames[-1]
        for slit in model.slits:
            # Pixel-center grid covering the slit's bounding box.
            x, y = wcstools.grid_from_bounding_box(slit.meta.wcs.bounding_box,
                                                   step=(1, 1),
                                                   center=True)

            ra, dec, lam = slit.meta.wcs(x, y)
            detector2slit = slit.meta.wcs.get_transform('detector',
                                                        'slit_frame')
            _, slit_y, _ = detector2slit(x, y)

            world_coordinates = np.array([lam, ra, dec, slit_y])
            imhdu = fits.ImageHDU(data=world_coordinates)
            imhdu.header['PLANE1'] = 'lambda, microns'
            imhdu.header['PLANE2'] = '{0}_x, arcsec'.format(output_frame)
            imhdu.header['PLANE3'] = '{0}_y, arcsec'.format(output_frame)
            imhdu.header['PLANE4'] = 'slit_y, relative to center (0, 0)'

            imhdu.header['SLIT'] = slit.name
            # add the overall subarray offset; slit x/ystart are 1-based
            imhdu.header['CRVAL1'] = (slit.xstart - 1 +
                                      model.meta.subarray.xstart)
            imhdu.header['CRVAL2'] = (slit.ystart - 1 +
                                      model.meta.subarray.ystart)

            imhdu.header['CRPIX1'] = 1
            imhdu.header['CRPIX2'] = 1
            imhdu.header['CTYPE1'] = 'pixel'
            imhdu.header['CTYPE2'] = 'pixel'
            hdulist.append(imhdu)

        if output is None:
            root = model.meta.filename.split('_')
            output = "".join([root[0], '_world_coordinates', '.fits'])
        hdulist.writeto(output, overwrite=True)
    finally:
        # Release the input model even when validation or writing fails.
        model.close()
예제 #7
0
def _format_elapsed_time(elapsed_seconds):
    """Return the run-time message, in seconds, minutes or hours.

    Values up to 60 s are reported in seconds. Longer runs are converted
    to minutes, and to hours once the minute count also exceeds 60.
    """
    elapsed = elapsed_seconds
    unit = "seconds"
    if elapsed > 60.0:
        elapsed = elapsed / 60.0  # in minutes
        unit = "minutes"
        if elapsed > 60.0:
            elapsed = elapsed / 60.  # in hours
            unit = "hours"
    return "* Script FS_PS.py took " + repr(elapsed) + " " + unit + " to finish."


def _interpolate_pathloss_correction(plcor_ref_ext, slitx_ref, slity_ref,
                                     wave_ref, wave_sci, slit_x, slit_y):
    """Compute per-pixel pathloss corrections for one slit.

    For each wavelength slice of the reference cube, a linear 2D
    interpolation at (slit_x, slit_y) gives one correction factor. Those
    factors are then linearly interpolated in wavelength onto every pixel
    of ``wave_sci``.

    Returns
    -------
    corr_vals : ndarray
        Correction factors with the same shape as ``wave_sci``.
    """
    # Reference (x, y) sample positions; only the first slice-sized chunk
    # is used for every wavelength slice.
    ref_xy = np.column_stack((slitx_ref.reshape(slitx_ref.size),
                              slity_ref.reshape(slitx_ref.size)))
    # Wavelength is assumed to vary only along axis 0 of the reference cube
    # (TODO confirm for every reference file), so one griddata call per
    # slice is enough.
    ref_waves = wave_ref[:, 0, 0]

    correction_list = []
    for lambda_val in ref_waves:
        # first slice whose wavelength matches lambda_val
        index = np.where(ref_waves == lambda_val)[0][0]
        plcor_slice = plcor_ref_ext[index].reshape(plcor_ref_ext[index].size)
        # 2D interpolation to get a single correction factor for the slice
        corr_val = scipy.interpolate.griddata(ref_xy[:plcor_slice.size],
                                              plcor_slice,
                                              np.asarray([slit_x, slit_y]),
                                              method='linear')
        correction_list.append(corr_val[0])

    corr_vals = np.interp(wave_sci.reshape(wave_sci.size), ref_waves,
                          np.array(correction_list))
    return corr_vals.reshape(wave_sci.shape)


def pathtest(step_input_filename,
             reffile,
             comparison_filename,
             writefile=True,
             show_figs=True,
             save_figs=False,
             threshold_diff=1e-7,
             debug=False):
    """
    Compare the pipeline pathloss correction with an independently
    calculated one, slit by slit, for NIRSpec fixed-slit (or BOTS) data.

    Args:
        step_input_filename: str, name of the output fits file from the
            source type step (with full path). The matching
            ``extract_2d.fits`` and ``pathloss_400A1.fits`` files are found
            by replacing ``srctype.fits`` in this name.
        reffile: str, path to the pathloss FS reference fits file
        comparison_filename: str, path to comparison pipeline pathloss file
        writefile: boolean, if True writes the fits files of the calculated
            pathloss and difference image.
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, whether to save the plots; file names are built
            from step_input_filename, the slit name and the slit counter
        threshold_diff: float, threshold difference between pipeline output
            and comparison file
        debug: boolean, if true a series of print statements will show
            on-screen
    Returns:
        - 1 plot per slit, if told to save and/or show them.
        - FINAL_TEST_RESULT: True only if at least one slit was tested and
          none FAILED; the string 'skip' when a required file is missing or
          slit names do not match.
        - result_msg: str, summary message
        - log_msgs: list, all print statements are captured in this variable
    """

    log_msgs = []

    # start the timer
    pathtest_start_time = time.time()

    # The header of the input file is read below, so check it exists first.
    if not os.path.isfile(step_input_filename):
        result_msg = 'Input file does NOT exist. Skipping pathloss test.'
        print(result_msg)
        log_msgs.append(result_msg)
        return 'skip', result_msg, log_msgs
    if debug:
        print('Input file does exist.')

    # get info from the rate file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    is_bots = exptype == "NRS_BRIGHTOBJ"

    msg = "path_loss_file  -->  Grating:" + grat + "   Filter:" + filt + "   EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)

    is_point_source = True

    # get the datamodel from the extract_2d output file
    extract2d_wcs_file = step_input_filename.replace("srctype.fits",
                                                     "extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_wcs_file)

    if writefile:
        # fits list to hold the calculated pathloss values for each slit
        outfile = fits.HDUList([fits.PrimaryHDU()])
        # fits list to hold the pipeline - calculated difference images
        compfile = fits.HDUList([fits.PrimaryHDU()])

    # list to determine if pytest is passed or not
    total_test_result = []

    # loop over the slits
    sltname_list = ["S200A1", "S200A2", "S400A1", "S1600A1"]
    msg = "Now looping through the slits. This may take a while... "
    print(msg)
    log_msgs.append(msg)
    if det == "NRS2":
        sltname_list.append("S200B1")

    # but check if data is BOTS
    if is_bots:
        sltname_list = ["S1600A1"]

    # get all the science extensions
    ps_uni_ext_list = get_ps_uni_extensions(reffile, is_point_source)

    # get files
    print("""Checking if files exist & obtaining datamodels.
          This takes a few minutes...""")
    if not os.path.isfile(comparison_filename):
        result_msg = 'Comparison file does NOT exist. Skipping pathloss test.'
        print(result_msg)
        log_msgs.append(result_msg)
        return 'skip', result_msg, log_msgs
    if debug:
        print('Comparison file does exist.')

    # get the comparison data model
    pathloss_pipe = datamodels.open(comparison_filename)
    # For the moment, the pipeline is using the wrong reference file for
    # slit 400A1, so read the file re-processed with the right reference
    # file when it is available.
    pathloss_400a1 = step_input_filename.replace("srctype.fits",
                                                 "pathloss_400A1.fits")
    pathloss_pipe_400a1 = None
    if os.path.isfile(pathloss_400a1):
        pathloss_pipe_400a1 = datamodels.open(pathloss_400a1)
    else:
        msg = "File " + pathloss_400a1 + " not found; using the comparison file for slit S400A1."
        print(msg)
        log_msgs.append(msg)
    if debug:
        print('got comparison datamodel!')

    # get the input data model
    pl = datamodels.open(step_input_filename)
    if debug:
        print('got input datamodel!')

    # loop through the wavelengths
    msg = "Looping through the wavelengths... "
    print(msg)
    log_msgs.append(msg)

    mode = "FS"
    step_input_filepath = step_input_filename.replace(".fits", "")

    for slit_val, (slit, pipe_slit) in enumerate(
            zip(pl.slits, pathloss_pipe.slits), start=1):

        slit_id = pipe_slit.name
        print('\nWorking with slitlet ', slit_id)

        if slit.name == slit_id:
            msg = """Slitlet name in fits file previous to pathloss
            and in pathloss output file are the same.
            """
            log_msgs.append(msg)
            print(msg)
        else:
            msg = """* Missmatch of slitlet names in fits file previous to
            pathloss and in pathloss output file. Skipping test.
            """
            log_msgs.append(msg)
            return 'skip', msg, log_msgs

        if debug:
            print("grat = ", grat)

        continue_pl_test = False
        if is_bots:
            slit = model
            continue_pl_test = True
        else:
            for slit_in_MultiSlitModel in pl.slits:
                if slit_in_MultiSlitModel.name == slit_id:
                    slit = slit_in_MultiSlitModel
                    continue_pl_test = True
                    break

        if not continue_pl_test:
            continue

        try:
            if is_point_source:
                ext = ps_uni_ext_list[0][slit_id]
                print("Retrieved point source extension")
            else:
                ext = ps_uni_ext_list[1][slit_id]
                print("WARNING: Retrieved extended source extension")
        except KeyError:
            ext = sltname_list.index(slit_id)
            print("Unable to retrieve extension.")

        wcs_obj = slit.meta.wcs

        # get the wavelength
        x, y = wcstools.grid_from_bounding_box(wcs_obj.bounding_box,
                                               step=(1, 1),
                                               center=True)
        _, _, wave = wcs_obj(x, y)  # wave is in microns
        wave_sci = wave * 10**(-6)  # microns --> meters

        # get positions of source in file:
        slit_x = slit.source_xpos
        slit_y = slit.source_ypos
        if debug:
            print("slit_x, slit_y", slit_x, slit_y)

        if slit_id == "S400A1":
            ext = 1 if is_point_source else 3
            reffile2use = "jwst-nirspec-a400.plrf.fits"
        else:
            reffile2use = reffile

        msg = "Using reference file: " + reffile2use
        print(msg)
        log_msgs.append(msg)

        plcor_ref_ext = fits.getdata(reffile2use, ext)
        if debug:
            print("ext:", ext)
        # Only the cube shape and WCS of extension 1 are needed; close the
        # file right away instead of leaking one handle per slit.
        with fits.open(reffile2use) as hdul:
            ref_shape = hdul[1].data.shape
            w = wcs.WCS(hdul[1].header)

        # make cube
        w1, y1, x1 = np.mgrid[:ref_shape[0], :ref_shape[1], :ref_shape[2]]
        slitx_ref, slity_ref, wave_ref = w.all_pix2world(x1, y1, w1, 0)

        previous_sci = slit.data
        pipe_correction = pipe_slit.pathloss
        if slit_id == "S400A1" and pathloss_pipe_400a1 is not None:
            for pipe_slit_400a1 in pathloss_pipe_400a1.slits:
                if pipe_slit_400a1.name == "S400A1":
                    pipe_correction = pipe_slit_400a1.pathloss
                    break
        if pipe_correction is None or len(pipe_correction) == 0:
            print(
                "Pipeline pathloss correction in datamodel is empty. Skipping testing this slit."
            )
            continue

        corr_vals = _interpolate_pathloss_correction(
            plcor_ref_ext, slitx_ref, slity_ref, wave_ref, wave_sci,
            slit_x, slit_y)
        corrected_array = previous_sci / corr_vals
        # residuals (pipe correction - my correction)
        corr_residuals = pipe_correction - corr_vals

        # set up generals for all the plots
        font = {'weight': 'normal', 'size': 7}
        matplotlib.rc('font', **font)

        # Plots: calculated correction, pipeline correction, residuals and
        # the science data after applying the calculated correction.
        fig = plt.figure()
        panels = [
            (221, corr_vals, 'Calculated Correction'),
            (222, pipe_correction, "Pipeline Correction"),
            (223, corr_residuals, 'Correction residuals'),
            (224, corrected_array, 'My science data after pathloss'),
        ]
        for position, image, title in panels:
            plt.subplot(position)
            plt.imshow(image,
                       norm=ImageNormalize(image),
                       aspect=10.0,
                       origin='lower',
                       cmap='viridis')
            plt.xlabel('dispersion in pixels')
            plt.ylabel('y in pixels')
            plt.title(title)
            plt.colorbar()
        fig.suptitle("FS PS Pathloss Correction Test for slit " + str(slit_id))

        if save_figs:
            plt_name = step_input_filepath + "_Pathloss_test_slitlet_"+str(mode) + "_" + str(slit_id) + "_" + \
                       str(slit_val) + ".png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        if not save_figs:
            if show_figs:
                msg = "Not saving plots because save_figs was set to False."
            else:
                msg = "Not making plots because both show_figs and save_figs were set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        # Close (not just clear) the figure so figures do not accumulate.
        plt.close(fig)

        # create fits file to hold the calculated pathloss for each slit
        if writefile:
            msg = "Saving the fits files with the calculated pathloss for each slit..."
            print(msg)
            log_msgs.append(msg)

            # calculated pathloss values
            outfile.append(fits.ImageHDU(corr_vals, name=slit_id))

            # pipeline - calculated difference values
            compfile.append(fits.ImageHDU(corr_residuals, name=slit_id))

        if corr_residuals[~np.isnan(corr_residuals)].size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN. " \
                   "Test will be set to FAILED."
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            corr_residuals = corr_residuals[
                np.where((corr_residuals != 999.0) & (corr_residuals < 0.1)
                         & (corr_residuals > -0.1))]  # ignore outliers
            if corr_residuals.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values. " \
                       "Test will be set to FAILED."
                print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
            else:
                stats, stats_print_strings = auxfunc.print_stats(
                    corr_residuals,
                    "Difference",
                    float(threshold_diff),
                    abs=True)
                _, corr_residuals_median, _ = stats
                log_msgs.extend(stats_print_strings)

                # This is the key argument for the assert pytest function
                if abs(corr_residuals_median) <= float(threshold_diff):
                    test_result = "PASSED"
                else:
                    test_result = "FAILED"

        msg = " *** Result of the test: " + test_result + "\n"
        print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

    if writefile:
        outfile_name = step_input_filename.replace("srctype",
                                                   "calcuated_FS_PS_pathloss")
        compfile_name = step_input_filename.replace(
            "srctype", "comparison_FS_PS_pathloss")

        # calculated pathloss values for each slit
        outfile.writeto(outfile_name, overwrite=True)

        # pipeline - calculated difference values for each slit
        compfile.writeto(compfile_name, overwrite=True)

        msg = "\nFits file with calculated pathloss values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline pathloss - calculated pathloss) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(compfile_name)
        log_msgs.append(compfile_name)

    # PASSED only if at least one slit was tested and none FAILED
    FINAL_TEST_RESULT = bool(total_test_result) and "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        msg = "\n *** Final pathloss test result reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED path_loss test."
    else:
        msg = "\n *** Final pathloss test result reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED path_loss test."

    # end the timer
    pathloss_tot_time = _format_elapsed_time(time.time() - pathtest_start_time)
    print(pathloss_tot_time)
    log_msgs.append(pathloss_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
예제 #8
0
# Plot metadata for every tested quantity:
# (description used in the "Unable to create plot" message, title suffix, histogram x-label, plot file suffix)
_MOS_WCS_PLOT_INFO = {
    "Wavelength Difference": (
        "wavelength",
        r"Relative wavelength difference = $\Delta \lambda$",
        r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{ESA}) / \lambda_{ESA}$",
        "_rel_wave_diffs.pdf"),
    "Slit-Y Difference": (
        "slit-y",
        r"Relative slit position = $\Delta$slit_y",
        r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{ESA}$)/slit_y$_{ESA}$",
        "_rel_slitY_diffs.pdf"),
    "MSA_X Difference": (
        "MSA-x",
        r"Relative MSA-x Difference = $\Delta$MSA_x",
        r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{ESA}$)/MSA_x$_{ESA}$",
        "_rel_MSAx_diffs.pdf"),
    "MSA_Y Difference": (
        "MSA-y",
        r"Relative MSA-y Difference = $\Delta$MSA_y",
        r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{ESA}$)/MSA_y$_{ESA}$",
        "_rel_MSAy_diffs.pdf"),
    "V2 difference": (
        "V2",
        r"Relative V2 Difference = $\Delta$V2",
        r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{ESA}$)/V2$_{ESA}$",
        "_rel_V2_diffs.pdf"),
    "V3 difference": (
        "V3",
        r"Relative V3 Difference = $\Delta$V3",
        r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{ESA}$)/V3$_{ESA}$",
        "_rel_V3_diffs.pdf"),
}


def _mos_log(msg, log_msgs):
    """Print *msg* and append it to *log_msgs*."""
    print(msg)
    log_msgs.append(msg)


def _mos_is_nan(value):
    """Return True when *value* is a NaN scalar (np.nan or a NumPy float NaN), else False."""
    try:
        return bool(np.isnan(value))
    except TypeError:
        # non-numeric statistics cannot be NaN
        return False


def _mos_check_shutter(label, pipe_value, esa_value, log_msgs, end=""):
    """Log whether the pipeline and ESA shutter coordinate *label* agree."""
    if pipe_value == esa_value:
        msg = "\n -> Same " + label + " for pipeline and ESA data: " + str(pipe_value)
    else:
        # str() on both sides: the ESA header values are not necessarily strings
        msg = ("\n -> Missmatch of " + label + " for pipeline and ESA data: "
               + str(pipe_value) + " (pipeline) vs " + str(esa_value) + " (ESA)")
    _mos_log(msg + end, log_msgs)


def _mos_read_esa_trace(esafile, esahdulist, det, log_msgs):
    """Read the ESA trace arrays that belong to detector *det*.

    Returns a dict with keys "wave", "slity", "msax", "msay", "pyw" (astropy WCS built from the
    LAMBDA extension header) and "v2v3x"/"v2v3y" (None when the ESA file has no V2V3 extensions),
    or None when the slitlet has to be skipped.
    """
    ext = {"NRS1": "1", "NRS2": "2"}.get(det)
    if ext is None:
        _mos_log("PTT does not know which ESA extensions match detector " + str(det)
                 + ", skipping test for this slitlet...", log_msgs)
        return None
    try:
        # DATA<n> is not compared, but a missing DATA extension still makes the slitlet skip
        fits.getdata(esafile, "DATA" + ext)
        trace = {
            "wave": fits.getdata(esafile, "LAMBDA" + ext),
            "slity": fits.getdata(esafile, "SLITY" + ext),
            "msax": fits.getdata(esafile, "MSAX" + ext),
            "msay": fits.getdata(esafile, "MSAY" + ext),
            "pyw": wcs.WCS(esahdulist["LAMBDA" + ext].header),
        }
    except KeyError:
        _mos_log("PTT did not find ESA extensions that match detector " + det
                 + ", skipping test for this slitlet...", log_msgs)
        return None
    try:
        trace["v2v3x"] = fits.getdata(esafile, "V2V3X" + ext)
        trace["v2v3y"] = fits.getdata(esafile, "V2V3Y" + ext)
    except KeyError:
        _mos_log("Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions.",
                 log_msgs)
        trace["v2v3x"] = trace["v2v3y"] = None
    return trace


def _mos_record_test(rel_diff_data, tested_quantity, threshold_diff, log_msgs, slitlet_results):
    """Log one relative-difference comparison and record its PASSED/FAILED result.

    rel_diff_data is the 4-tuple returned by auxfunc.get_reldiffarr_and_stats. Element [1] of its
    statistics is handed to auxfunc.does_median_pass_tes to decide the result.
    Returns (relative-difference image, not-NaN values, statistics) for plotting.
    """
    rel_diff_img, notnan_rel_diff, notnan_stats, print_stats_strings = rel_diff_data
    # the statistics lines are only added to the log here, not printed again
    log_msgs.extend(print_stats_strings)
    result = auxfunc.does_median_pass_tes(notnan_stats[1], threshold_diff)
    _mos_log('Result for test of ' + tested_quantity + ': ' + result, log_msgs)
    slitlet_results.append({tested_quantity: result})
    return rel_diff_img, notnan_rel_diff, notnan_stats


def _mos_plot_rel_diff(tested_quantity, plot_data, main_title, infile_name, slitlet_name, det,
                       show_figs, save_figs, log_msgs):
    """Show and/or save the image + histogram plot of one relative difference."""
    description, title_suffix, xlabel, file_suffix = _MOS_WCS_PLOT_INFO[tested_quantity]
    rel_diff_img, notnan_rel_diff, notnan_stats = plot_data
    if _mos_is_nan(notnan_stats[1]):
        _mos_log("Unable to create plot of relative " + description + " difference.", log_msgs)
        return
    info_img = [main_title + title_suffix + "\n", "x (pixels)", "y (pixels)"]
    # binning for the histograms; if None the plotting function calculates it automatically
    bins = 15
    info_hist = [xlabel, "N", bins, notnan_stats]
    basenameinfile_name = os.path.basename(infile_name)
    plt_name = infile_name.replace(basenameinfile_name, slitlet_name + "_" + det + file_suffix)
    auxfunc.plt_two_2Dimgandhist(rel_diff_img,
                                 notnan_rel_diff,
                                 info_img,
                                 info_hist,
                                 plt_name=plt_name,
                                 plt_origin=None,
                                 show_figs=show_figs,
                                 save_figs=save_figs)


def compare_wcs(infile_name,
                esa_files_path,
                msa_conf_name,
                show_figs=True,
                save_figs=False,
                threshold_diff=1.0e-7,
                mode_used=None,
                debug=False):
    """
    This function does the WCS comparison (MOS mode) between the world coordinates calculated using
    the pipeline data model and the ESA intermediary files.

    For every open slitlet, the pixel grid of the ESA trace is transformed with the pipeline WCS and
    the relative differences in wavelength, slit-y, MSA x/y and, when the ESA file provides them,
    V2/V3 are tested against threshold_diff.

    Args:
        infile_name: str, name of the output fits file from the assign_wcs step (with full path)
        esa_files_path: str, full path of where to find all ESA intermediary products to make comparisons for the tests
        msa_conf_name: str, full path where to find the shutter configuration file
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots or not
        threshold_diff: float, threshold difference between pipeline output and ESA file
        mode_used: string, mode used in the PTT configuration file ("MOS_sim" tests the extract_2d product)
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - plots, if told to save and/or show them.
        - FINAL_TEST_RESULT: str, "PASSED" when at least one test ran and none FAILED, "FAILED"
          otherwise, or "skip" when the MSA configuration file or an ESA file could not be found
        - log_msgs: list, all print statements captured in this variable

    """
    import shutil

    log_msgs = []

    # the MOS simulations are tested on the extract_2d product instead of the assign_wcs one
    if mode_used == "MOS_sim":
        infile_name = infile_name.replace("assign_wcs", "extract_2d")
    _mos_log('wcs validation test infile_name= ' + infile_name, log_msgs)

    # get detector, grating, filter, lamp and MSA metadata file name from the main header
    det = fits.getval(infile_name, "DETECTOR", 0)
    lamp = fits.getval(infile_name, "LAMP", 0)
    grat = fits.getval(infile_name, "GRATING", 0)
    filt = fits.getval(infile_name, "FILTER", 0)
    msametfl = fits.getval(infile_name, "MSAMETFL", 0)
    _mos_log("from assign_wcs file  -->     Detector: " + det + "   Grating: " + grat
             + "   Filter: " + filt + "   Lamp: " + lamp, log_msgs)

    # check that shutter configuration file in header is the same as given in PTT_config file
    if msametfl != os.path.basename(msa_conf_name):
        _mos_log("* WARNING! MSA config file name given in PTT_config file does not match the MSAMETFL keyword in main header.\n",
                 log_msgs)

    # copy the MSA shutter configuration file into the working directory;
    # shutil.copy raises FileNotFoundError when the file does not exist
    copied_msa_file = None
    try:
        copied_msa_file = shutil.copy(msa_conf_name, ".")
    except shutil.SameFileError:
        # already in the working directory: nothing was copied, so nothing may be deleted later
        pass
    except FileNotFoundError:
        _mos_log(" * PTT is not able to locate the MSA shutter configuration file. Please make sure that the msa_conf_name variable in",
                 log_msgs)
        _mos_log("   the PTT_config.cfg file is pointing exactly to where the fits file exists (i.e. full path and name). ",
                 log_msgs)
        _mos_log("   -> The WCS test is now set to skip and no plots will be generated. ", log_msgs)
        return "skip", log_msgs

    # per-slitlet test results, keyed by "<row>_<column>"
    total_test_result = OrderedDict()
    try:
        # get shutter info from metadata
        shutter_info = fits.getdata(msa_conf_name, extname="SHUTTER_INFO")  # this is generally ext=2
        pslit_list = shutter_info.field("slitlet_id").tolist()
        quad = shutter_info.field("shutter_quadrant")
        row = shutter_info.field("shutter_row")
        col = shutter_info.field("shutter_column")
        _mos_log('Using this MSA shutter configuration file: ' + msa_conf_name, log_msgs)

        # get the slits to test from the datamodel
        if mode_used == "MOS_sim":
            # the extract_2d / flat_field products carry one WCS per slit
            model = datamodels.MultiSlitModel(infile_name)
            slits_list = model.slits
        else:
            img = datamodels.ImageModel(infile_name)
            # open slitlets that are also projected on the detector (nirspec.get_open_slits would
            # also return open slitlets that fall off the detector)
            slits_list = img.meta.wcs.get_transform('gwa', 'slit_frame').slits
            if debug:
                print("Instrument Configuration")
                print("Detector: {}".format(img.meta.instrument.detector))
                print("GWA: {}".format(img.meta.instrument.grating))
                print("Filter: {}".format(img.meta.instrument.filter))
                print("Lamp: {}".format(img.meta.instrument.lamp_state))
                print("GWA_XTILT: {}".format(img.meta.instrument.gwa_xtilt))
                print("GWA_YTILT: {}".format(img.meta.instrument.gwa_ytilt))
                print("GWA_TTILT: {}".format(img.meta.instrument.gwa_tilt))

        for slit in slits_list:
            name = slit.name
            _mos_log("\nWorking with slit: " + str(name), log_msgs)

            # get the right index in the list of open shutters
            slitlet_idx = pslit_list.index(int(name))

            # Get the ESA trace
            _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file()
            _mos_log("Using this raw data file to find the corresponding ESA file: " + raw_data_root_file,
                     log_msgs)
            q, r, c = quad[slitlet_idx], row[slitlet_idx], col[slitlet_idx]
            _mos_log("Pipeline shutter info:   quadrant= " + str(q) + "   row= " + str(r)
                     + "   col=" + str(c), log_msgs)
            esafile = auxfunc.get_esafile(esa_files_path, raw_data_root_file, "MOS", [q, r, c])

            # skip the whole test if the esafile was not found
            if "ESA file not found" in esafile:
                _mos_log(" * compare_wcs_mos.py is exiting because the corresponding ESA file was not found.",
                         log_msgs)
                _mos_log("   -> The WCS test is now set to skip and no plots will be generated. ", log_msgs)
                return "skip", log_msgs

            # a two-element result whose second entry is empty is reduced to its first entry (the file path)
            if len(esafile) == 2 and len(esafile[-1]) == 0:
                esafile = esafile[0]
            _mos_log("Using this ESA file: \n" + str(esafile), log_msgs)

            with fits.open(esafile) as esahdulist:
                print("* ESA file contents ")
                esahdulist.info()
                esa_header = esahdulist[0].header
                esa_shutter_i = esa_header['SHUTTERI']
                esa_shutter_j = esa_header['SHUTTERJ']
                esa_quadrant = esa_header['QUADRANT']
                if debug:
                    _mos_log("ESA shutter info:   quadrant=" + str(esa_quadrant) + "   shutter_i="
                             + str(esa_shutter_i) + "   shutter_j=" + str(esa_shutter_j), log_msgs)
                # first check if ESA shutter info is the same as pipeline
                _mos_log("For slitlet" + str(name), log_msgs)
                _mos_check_shutter("quadrant", q, esa_quadrant, log_msgs)
                _mos_check_shutter("row", r, esa_shutter_i, log_msgs)
                _mos_check_shutter("column", c, esa_shutter_j, log_msgs, end="\n")
                trace = _mos_read_esa_trace(esafile, esahdulist, det, log_msgs)
            if trace is None:
                continue

            # get the WCS object for this particular slit
            if mode_used == "MOS_sim":
                # every slit of the MultiSlitModel carries its own WCS
                wcs_slice = slit.wcs
            else:
                try:
                    wcs_slice = nirspec.nrs_wcs_set_input(img, name)
                except ValueError:
                    _mos_log("* WARNING: Slitlet " + str(name)
                             + " was not found in the model. Skipping test for this slitlet.", log_msgs)
                    continue
                if debug:
                    # To get specific pixel values use following syntax:
                    det2slit = wcs_slice.get_transform('detector', 'slit_frame')
                    slitx, slity, lam = det2slit(700, 1080)
                    print("slitx: ", slitx)
                    print("slity: ", slity)
                    print("lambda: ", lam)
                    # The number of inputs and outputs in each frame can vary
                    print('Number on inputs: ', det2slit.n_inputs)
                    print('Number on outputs: ', det2slit.n_outputs)

            # Create x, y indices using the Trace WCS
            esa_wave = trace["wave"]
            esa_slity = trace["slity"]
            pipey, pipex = np.mgrid[:esa_wave.shape[0], :esa_wave.shape[1]]
            esax, esay = trace["pyw"].all_pix2world(pipex, pipey, 0)
            if det == "NRS2":
                _mos_log("NRS2 needs a flip", log_msgs)
                esax = 2049 - esax
                esay = 2049 - esay
            # positions handed to the pipeline transforms: the ESA positions shifted by -1
            # (presumably 1-based -> 0-based pixels)
            x_pipe, y_pipe = esax - 1, esay - 1

            # "row_column" label for the results and plot file names; str() (not repr()) so that
            # NumPy 2 scalars do not render as "np.int16(...)"
            slitlet_name = str(r) + "_" + str(c)
            slitlet_results = []
            plot_entries = []

            # Wavelength: wcs_slice returns RA, DEC, LAMBDA; the pipeline wavelength is scaled by
            # 1e-6 (microns -> meters, presumably the ESA unit) before comparing
            _, _, pwave = wcs_slice(x_pipe, y_pipe)
            pwave *= 10**-6
            tested_quantity = "Wavelength Difference"
            rel_diff_data = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, esa_slity, esa_wave, pwave, tested_quantity)
            plot_entries.append((tested_quantity, _mos_record_test(
                rel_diff_data, tested_quantity, threshold_diff, log_msgs, slitlet_results)))

            # Slit-y (relative differences)
            det2slit = wcs_slice.get_transform('detector', 'slit_frame')
            _, slity, _ = det2slit(x_pipe, y_pipe)
            tested_quantity = "Slit-Y Difference"
            rel_diff_data = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, esa_slity, esa_slity, slity, tested_quantity, abs=False)
            plot_entries.append((tested_quantity, _mos_record_test(
                rel_diff_data, tested_quantity, threshold_diff, log_msgs, slitlet_results)))

            # MSA x and y
            detector2msa = wcs_slice.get_transform("detector", "msa_frame")
            pmsax, pmsay, _ = detector2msa(x_pipe, y_pipe)
            for tested_quantity, esa_values, pipe_values in (("MSA_X Difference", trace["msax"], pmsax),
                                                             ("MSA_Y Difference", trace["msay"], pmsay)):
                rel_diff_data = auxfunc.get_reldiffarr_and_stats(
                    threshold_diff, esa_slity, esa_values, pipe_values, tested_quantity)
                plot_entries.append((tested_quantity, _mos_record_test(
                    rel_diff_data, tested_quantity, threshold_diff, log_msgs, slitlet_results)))

            # V2 and V3, only when the ESA file provides them
            if trace["v2v3x"] is not None:
                detector2v2v3 = wcs_slice.get_transform("detector", "v2v3")
                pv2, pv3, _ = detector2v2v3(x_pipe, y_pipe)
                for tested_quantity, esa_values, pipe_values in (("V2 difference", trace["v2v3x"], pv2),
                                                                 ("V3 difference", trace["v2v3y"], pv3)):
                    rel_diff_data = auxfunc.get_reldiffarr_and_stats(
                        threshold_diff, esa_slity, esa_values, pipe_values, tested_quantity)
                    # a positive first statistic triggers converting the pipeline values
                    # (arcsec -> degrees) to compare with ESA
                    if rel_diff_data[-2][0] > 0.0:
                        print("\nConverting pipeline results to degrees to compare with ESA")
                        rel_diff_data = auxfunc.get_reldiffarr_and_stats(
                            threshold_diff, esa_slity, esa_values, pipe_values / 3600., tested_quantity)
                    plot_entries.append((tested_quantity, _mos_record_test(
                        rel_diff_data, tested_quantity, threshold_diff, log_msgs, slitlet_results)))

            total_test_result[slitlet_name] = slitlet_results

            # PLOTS
            if show_figs or save_figs:
                main_title = filt + "   " + grat + "   SLITLET=" + slitlet_name + "\n"
                for tested_quantity, plot_data in plot_entries:
                    _mos_plot_rel_diff(tested_quantity, plot_data, main_title, infile_name,
                                       slitlet_name, det, show_figs, save_figs, log_msgs)
            else:
                _mos_log("NO plots were made because show_figs and save_figs were both set to False. \n",
                         log_msgs)
    finally:
        # remove only the copy made above, never a file that already existed in the working directory
        if copied_msa_file is not None and os.path.exists(copied_msa_file):
            os.remove(copied_msa_file)

    # PASSED only if at least one test ran and none FAILED; otherwise FAILED
    any_test_ran = False
    any_test_failed = False
    for sl, testlist in total_test_result.items():
        for tdict in testlist:
            for t, tr in tdict.items():
                any_test_ran = True
                if tr == "FAILED":
                    any_test_failed = True
                    _mos_log("\n * The test of " + t + " for slitlet " + sl + "  FAILED.", log_msgs)
                else:
                    _mos_log("\n * The test of " + t + " for slitlet " + sl + "  PASSED.", log_msgs)
    final_test_result = "PASSED" if any_test_ran and not any_test_failed else "FAILED"

    if final_test_result == "PASSED":
        _mos_log("\n *** Final result for assign_wcs test will be reported as PASSED *** \n", log_msgs)
    else:
        _mos_log("\n *** Final result for assign_wcs test will be reported as FAILED *** \n", log_msgs)

    return final_test_result, log_msgs
예제 #9
0
def do_correction_fixedslit(data,
                            pathloss,
                            inverse=False,
                            source_type=None,
                            correction_pars=None):
    """Path loss correction for NIRSpec fixed-slit modes

    Data is modified in-place. Each slit's correction is applied exactly once.

    Parameters
    ----------
    data : jwst.datamodel.DataModel
        The NIRSpec fixed-slit data to be corrected.

    pathloss : jwst.datamodel.DataModel
        The pathloss reference data.

    inverse : boolean
        Invert the math operation applied to the science data (multiply by the
        correction instead of dividing by it).

    source_type : str or None
        Force processing using the specified source type.

    correction_pars : jwst.datamodels.MultiSlitModel or None
        The precomputed pathloss to apply instead of recalculation. Its slits
        are looked up by the index of the slit in ``data.slits``.

    Returns
    -------
    corrections : jwst.datamodel.MultiSlitModel
        The pathloss corrections applied, one entry per input slit.
    """
    exp_type = data.meta.exposure.type

    # Loop over all slits contained in the input
    corrections = datamodels.MultiSlitModel()
    for slit_number, slit in enumerate(data.slits):
        log.info(f'Working on slit {slit.name}')

        if correction_pars:
            correction = correction_pars.slits[slit_number]
        else:
            correction = _corrections_for_fixedslit(slit, pathloss, exp_type,
                                                    source_type)
        # Appended even when empty so entries stay in input-slit order,
        # matching the index lookup used for ``correction_pars`` above.
        corrections.slits.append(correction)

        if not correction:
            log.warning(
                f'No correction provided for slit {slit_number}. Skipping')
            continue

        # Apply the correction (once)
        if not inverse:
            slit.data /= correction.data
        else:
            slit.data *= correction.data
        # NOTE(review): err and variances are divided regardless of `inverse`;
        # confirm that this is intended for the inverse operation.
        slit.err /= correction.data
        slit.var_poisson /= correction.data**2
        slit.var_rnoise /= correction.data**2
        if slit.var_flat is not None and np.size(slit.var_flat) > 0:
            slit.var_flat /= correction.data**2
        slit.pathloss_point = correction.pathloss_point
        slit.pathloss_uniform = correction.pathloss_uniform

    # Set step status to complete
    data.meta.cal_step.pathloss = 'COMPLETE'

    return corrections
예제 #10
0
def _print_and_log(msg, log_msgs):
    """Print ``msg`` and append it to the ``log_msgs`` list."""
    print(msg)
    log_msgs.append(msg)


def _evaluate_quantity(rel_diff_data, threshold_diff, pipeslit,
                       tested_quantity, total_test_result, log_msgs):
    """Log, test and record one pipeline-vs-truth relative-difference comparison.

    Args:
        rel_diff_data: 4-tuple returned by ``auxfunc.get_reldiffarr_and_stats``
            (difference image, non-NaN differences, statistics, printable
            statistics messages).
        threshold_diff: float, threshold passed to ``auxfunc.does_median_pass_tes``
        pipeslit: str, name of the slit being tested
        tested_quantity: str, label of the tested quantity (e.g. "Wavelength Difference")
        total_test_result: dict mapping slit name -> {tested_quantity: result}; updated in place
        log_msgs: list of log messages; updated in place

    Returns:
        The result string returned by ``auxfunc.does_median_pass_tes``.
    """
    _, _, stats, print_stats = rel_diff_data
    log_msgs.extend(print_stats)
    # stats[1] is the statistic tested against the threshold (presumably the median)
    test_result = auxfunc.does_median_pass_tes(stats[1], threshold_diff)
    msg = "\n * Result of the test for " + tested_quantity + ":  " + test_result + "\n"
    _print_and_log(msg, log_msgs)
    # Accumulate every tested quantity for this slit rather than replacing the
    # slit's entry, so the final verdict sees all of them.
    total_test_result.setdefault(pipeslit, OrderedDict())[tested_quantity] = test_result
    return test_result


def _stat_is_nan(value):
    """Return True if ``value`` is NaN (``np.nan`` itself or any NaN float scalar)."""
    return value is np.nan or (isinstance(value, (float, np.floating)) and np.isnan(value))


def _plot_file_name(infile_name, basenameinfile_name, output_directory,
                    pipeslit, det, specific_plt_name):
    """Build the output path of one comparison plot.

    When ``infile_name`` is a path the plot goes next to it; otherwise it goes
    to ``output_directory`` or, if that is None, the current working directory.
    """
    plt_basename = pipeslit + "_" + det + specific_plt_name
    if isinstance(infile_name, str):
        return infile_name.replace(basenameinfile_name, plt_basename)
    if output_directory is not None:
        return os.path.join(output_directory, plt_basename)
    plt_name = os.path.join(os.getcwd(), plt_basename)
    print("No output_directory was provided. Figures will be saved in current working directory:")
    print(plt_name + "\n")
    return plt_name


def compare_wcs(infile_name,
                truth_file=None,
                esa_files_path=None,
                show_figs=True,
                save_figs=False,
                threshold_diff=1.0e-7,
                raw_data_root_file=None,
                output_directory=None,
                debug=False):
    """
    This function does the WCS comparison from the world coordinates calculated using the pipeline
    data model with the truth files or the ESA intermediary files (to create new truth files).

    Args:
        infile_name: str or data model, the output of the assign_wcs step (fits file name with full
                     path, or an already opened model)
        truth_file: str, full path to the 'truth' (or benchmark) file to compare to; when given, it
                    takes precedence over esa_files_path
        esa_files_path: str or None, path to file all the ESA intermediary files
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots or not
        threshold_diff: float, threshold difference between pipeline output and truth file
        raw_data_root_file: None or string, name of the raw file that produced the _uncal.fits file for caldetector1
        output_directory: None or string, path to the output_directory where to save the plots
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - plots, if told to save and/or show them.
        - FINAL_TEST_RESULT: str, "PASSED" only if every recorded test passed, "FAILED" otherwise
          (including when no test could be run), or "skip" if the ESA data could not be used
        - log_msgs: list, all print statements captured in this variable

    Raises:
        ValueError: if neither truth_file nor esa_files_path is provided.
    """
    if truth_file is None and esa_files_path is None:
        raise ValueError("Either truth_file or esa_files_path must be provided.")

    log_msgs = []

    if truth_file is not None:
        # get the model from the "truth" (or comparison) file
        print("Information from the 'truth' (or comparison) file ")
        with fits.open(truth_file) as truth_hdul:
            # HDUList.info() prints the summary itself and returns None
            truth_hdul.info()
        truth_img = datamodels.ImageModel(truth_file)
        # get the open slits in the truth file
        open_slits_truth = truth_img.meta.wcs.get_transform('gwa', 'slit_frame').slits
        # open the datamodel
        truth_model = datamodels.MultiSlitModel(truth_file)
        # a truth file takes precedence over the ESA intermediary files
        esa_files_path = None
        print("Comparing to ST 'truth' file.")
    else:
        print("Comparing to ESA data")

    # get grating and filter info from the rate file header
    if isinstance(infile_name, str):
        # get the datamodel from the assign_wcs output file
        img = datamodels.ImageModel(infile_name)
        _print_and_log('infile_name=' + infile_name, log_msgs)
        basenameinfile_name = os.path.basename(infile_name)
    else:
        img = infile_name
        basenameinfile_name = ""

    det = img.meta.instrument.detector
    lamp = img.meta.instrument.lamp_state
    grat = img.meta.instrument.grating
    filt = img.meta.instrument.filter
    msg = "from assign_wcs file  -->     Detector: " + det + "   Grating: " + grat + "   Filter: " + \
          filt + "   Lamp: " + lamp
    _print_and_log(msg, log_msgs)
    print("GWA_XTILT: {}".format(img.meta.instrument.gwa_xtilt))
    print("GWA_YTILT: {}".format(img.meta.instrument.gwa_ytilt))
    print("GWA_TTILT: {}".format(img.meta.instrument.gwa_tilt))

    # slit name -> {tested quantity: "PASSED"/"FAILED"}, used for the final verdict
    total_test_result = OrderedDict()

    # mapping the ESA slit names to pipeline names
    map_slit_names = {
        'SLIT_A_1600': 'S1600A1',
        'SLIT_A_200_1': 'S200A1',
        'SLIT_A_200_2': 'S200A2',
        'SLIT_A_400': 'S400A1',
        'SLIT_B_200': 'S200B1',
    }

    # ESA data files that are not in the regular directory tree, with the NID needed to find them
    special_cutout_nids = {
        "NRSSMOS-MOD-G1H-02-5344031756_1_491_SE_2015-12-10T03h25m56.fits": "30055",
        "NRSSMOS-MOD-G1H-02-5344031756_1_492_SE_2015-12-10T03h25m56.fits": "30055",
        "NRSSMOS-MOD-G2M-01-5344191938_1_491_SE_2015-12-10T19h29m26.fits": "30205",
        "NRSSMOS-MOD-G3H-02-5344120942_1_491_SE_2015-12-10T12h18m25.fits": "30133",
        "NRSSMOS-MOD-G3H-02-5344120942_1_492_SE_2015-12-10T12h18m25.fits": "30133",
    }

    # To get the open and projected on the detector slits of the pipeline processed file
    open_slits = img.meta.wcs.get_transform('gwa', 'slit_frame').slits
    for opslit in open_slits:
        pipeslit = opslit.name
        _print_and_log("\nWorking with slit: " + pipeslit, log_msgs)

        compare_to_esa_data = esa_files_path is not None
        if compare_to_esa_data:
            # Get the ESA trace
            if raw_data_root_file is None:
                _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file(
                    infile_name)
            specifics = [pipeslit]

            nid = special_cutout_nids.get(raw_data_root_file)
            if nid is not None:
                _print_and_log("Using NID = " + nid, log_msgs)

            esafile, esafile_log_msgs = auxfunc.get_esafile(esa_files_path,
                                                            raw_data_root_file,
                                                            "FS",
                                                            specifics,
                                                            nid=nid)
            log_msgs.extend(esafile_log_msgs)

            # skip the test if the esafile was not found
            if esafile == "ESA file not found":
                _print_and_log(" * compare_wcs_fs.py is exiting because the corresponding ESA file was not found.",
                               log_msgs)
                _print_and_log("   -> The WCS test is now set to skip and no plots will be generated. ", log_msgs)
                FINAL_TEST_RESULT = "skip"
                return FINAL_TEST_RESULT, log_msgs

            # Open the trace in the esafile
            _print_and_log("Using this ESA file: \n" + esafile, log_msgs)
            with fits.open(esafile) as esahdulist:
                print("* ESA file contents ")
                esahdulist.info()
                esa_slit_id = map_slit_names[esahdulist[0].header['SLITID']]
                # first check is esa_slit == to pipe_slit?
                if pipeslit == esa_slit_id:
                    msg = "\n -> Same slit found for pipeline and ESA data: " + pipeslit + "\n"
                else:
                    msg = "\n -> Mismatch of slits for pipeline and ESA data: " + pipeslit + ", " + \
                          esa_slit_id + "\n"
                _print_and_log(msg, log_msgs)

                # Assign variables according to detector
                skipv2v3test = True
                if det == "NRS1":
                    try:
                        # also acts as a check that the NRS1 data extension exists
                        fits.getdata(esafile, "DATA1")
                        truth_wave = fits.getdata(esafile, "LAMBDA1")
                        truth_slity = fits.getdata(esafile, "SLITY1")
                        truth_msax = fits.getdata(esafile, "MSAX1")
                        truth_msay = fits.getdata(esafile, "MSAY1")
                        pyw = wcs.WCS(esahdulist['LAMBDA1'].header)
                        try:
                            truth_v2 = fits.getdata(esafile, "V2V3X1")
                            truth_v3 = fits.getdata(esafile, "V2V3Y1")
                            skipv2v3test = False
                        except KeyError:
                            _print_and_log("Skipping tests for V2 and V3 because ESA file does not contain "
                                           "corresponding extensions.", log_msgs)
                    except KeyError:
                        _print_and_log("This file does not contain data for detector NRS1. "
                                       "Skipping test for this slit.", log_msgs)
                        continue

                if det == "NRS2":
                    try:
                        # also acts as a check that the NRS2 data extension exists
                        fits.getdata(esafile, "DATA2")
                        truth_wave = fits.getdata(esafile, "LAMBDA2")
                        truth_slity = fits.getdata(esafile, "SLITY2")
                        truth_msax = fits.getdata(esafile, "MSAX2")
                        truth_msay = fits.getdata(esafile, "MSAY2")
                        pyw = wcs.WCS(esahdulist['LAMBDA2'].header)
                        try:
                            truth_v2 = fits.getdata(esafile, "V2V3X2")
                            truth_v3 = fits.getdata(esafile, "V2V3Y2")
                            skipv2v3test = False
                        except KeyError:
                            _print_and_log("Skipping tests for V2 and V3 because ESA file does not contain "
                                           "corresponding extensions.", log_msgs)
                    except KeyError:
                        _print_and_log("\n * compare_wcs_fs.py is exiting because there are no extensions that "
                                       "match detector NRS2 in the ESA file.", log_msgs)
                        _print_and_log("   -> The WCS test is now set to skip and no plots will be generated. \n",
                                       log_msgs)
                        FINAL_TEST_RESULT = "skip"
                        return FINAL_TEST_RESULT, log_msgs

            # Create x, y indices using the trace WCS from ESA
            pipey, pipex = np.mgrid[:truth_wave.shape[0], :truth_wave.shape[1]]
            esax, esay = pyw.all_pix2world(pipex, pipey, 0)

            if det == "NRS2":
                esax = 2049 - esax
                esay = 2049 - esay
                _print_and_log("Flipped ESA data for detector NRS2 comparison with pipeline.", log_msgs)

            # remove 1 to start from 0
            truth_x, truth_y = esax - 1, esay - 1

        else:
            # Comparing to the ST truth file: the current open slit must also be open there
            if not any(truth_open_slit.name == pipeslit for truth_open_slit in open_slits_truth):
                _print_and_log("\n * Script compare_wcs_fs.py is exiting because open slit " + pipeslit +
                               " is not open in truth file.", log_msgs)
                _print_and_log("   -> The WCS test is now set to FAILED and no plots will be generated. \n",
                               log_msgs)
                FINAL_TEST_RESULT = "FAILED"
                return FINAL_TEST_RESULT, log_msgs

            # get the WCS object for this particular truth slit
            slit_wcs = nirspec.nrs_wcs_set_input(truth_model, pipeslit)
            truth_x, truth_y = wcstools.grid_from_bounding_box(
                slit_wcs.bounding_box, step=(1, 1), center=True)
            _, _, truth_wave = slit_wcs(truth_x, truth_y)  # wave is in microns
            truth_wave *= 10**-6  # convert microns to meters
            # get the truths to compare to
            truth_det2slit = slit_wcs.get_transform('detector', 'slit_frame')
            _, truth_slity, _ = truth_det2slit(truth_x, truth_y)
            truth_det2msa = slit_wcs.get_transform("detector", "msa_frame")
            truth_msax, truth_msay, _ = truth_det2msa(truth_x, truth_y)
            truth_det2v2v3 = slit_wcs.get_transform("detector", "v2v3")
            truth_v2, truth_v3, _ = truth_det2v2v3(truth_x, truth_y)
            skipv2v3test = False

        # get the WCS object for this particular slit
        wcs_slit = nirspec.nrs_wcs_set_input(img, pipeslit)
        # print(wcs_slit) would list all available transforms.

        # In different observing modes the WCS may have different coordinate frames.
        available_frames = wcs_slit.available_frames
        print("Available frames: ", available_frames)

        if debug:
            # To get specific pixel values use following syntax:
            det2slit = wcs_slit.get_transform('detector', 'slit_frame')
            slitx, slity, lam = det2slit(700, 1080)
            print("slitx: ", slitx)
            print("slity: ", slity)
            print("lambda: ", lam)

            # The number of inputs and outputs in each frame can vary. This can be checked with:
            print('Number on inputs: ', det2slit.n_inputs)
            print('Number on outputs: ', det2slit.n_outputs)

        # Apply the WCS bounding box only for FULL-frame data; for subarrays the
        # transforms are evaluated with with_bounding_box=False.
        subarray_name = img.meta.subarray.name or ""
        bounding_box = "FULL" in subarray_name

        # (title suffix, histogram x label, rel-diff data, description, plot file suffix) per quantity
        plot_specs = []

        _, _, pwave = wcs_slit(truth_x, truth_y, with_bounding_box=bounding_box)  # => RA, DEC, LAMBDA
        pwave *= 10**-6  # convert microns to meters

        # Wavelength
        tested_quantity = "Wavelength Difference"
        rel_diff_pwave_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, truth_slity, truth_wave, pwave, tested_quantity)
        _evaluate_quantity(rel_diff_pwave_data, threshold_diff, pipeslit, tested_quantity,
                           total_test_result, log_msgs)
        plot_specs.append((
            r"Relative wavelength difference = $\Delta \lambda$",
            r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{truth}) / \lambda_{truth}$",
            rel_diff_pwave_data, "relative wavelength difference", "_rel_wave_diffs.png"))

        # Slit-y
        det2slit = wcs_slit.get_transform('detector', 'slit_frame')
        _, slity, _ = det2slit(truth_x, truth_y, with_bounding_box=bounding_box)
        tested_quantity = "Slit-Y Difference"
        rel_diff_pslity_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, truth_slity, truth_slity, slity, tested_quantity, absolute=False)
        _evaluate_quantity(rel_diff_pslity_data, threshold_diff, pipeslit, tested_quantity,
                           total_test_result, log_msgs)
        plot_specs.append((
            r"Relative slit position = $\Delta$slit_y",
            r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{truth}$)/slit_y$_{truth}$",
            rel_diff_pslity_data, "relative slit position", "_rel_slitY_diffs.png"))

        # MSA x and y
        detector2msa = wcs_slit.get_transform("detector", "msa_frame")
        pmsax, pmsay, _ = detector2msa(truth_x, truth_y, with_bounding_box=bounding_box)
        tested_quantity = "MSA_X Difference"
        reldiffpmsax_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, truth_slity, truth_msax, pmsax, tested_quantity)
        _evaluate_quantity(reldiffpmsax_data, threshold_diff, pipeslit, tested_quantity,
                           total_test_result, log_msgs)
        plot_specs.append((
            r"Relative MSA-x Difference = $\Delta$MSA_x",
            r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{truth}$)/MSA_x$_{truth}$",
            reldiffpmsax_data, "relative MSA-x difference", "_rel_MSAx_diffs.png"))

        tested_quantity = "MSA_Y Difference"
        reldiffpmsay_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, truth_slity, truth_msay, pmsay, tested_quantity)
        _evaluate_quantity(reldiffpmsay_data, threshold_diff, pipeslit, tested_quantity,
                           total_test_result, log_msgs)
        plot_specs.append((
            r"Relative MSA-y Difference = $\Delta$MSA_y",
            r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{truth}$)/MSA_y$_{truth}$",
            reldiffpmsay_data, "relative MSA-y difference", "_rel_MSAy_diffs.png"))

        # V2 and V3
        if not skipv2v3test and 'v2v3' in available_frames:
            detector2v2v3 = wcs_slit.get_transform("detector", "v2v3")
            pv2, pv3, _ = detector2v2v3(truth_x, truth_y, with_bounding_box=bounding_box)
            v2v3_specs = (
                ("V2 difference", truth_v2, pv2,
                 r"Relative V2 Difference = $\Delta$V2",
                 r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{truth}$)/V2$_{truth}$",
                 "relative V2 difference", "_rel_V2_diffs.png"),
                ("V3 difference", truth_v3, pv3,
                 r"Relative V3 Difference = $\Delta$V3",
                 r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{truth}$)/V3$_{truth}$",
                 "relative V3 difference", "_rel_V3_diffs.png"),
            )
            for (tested_quantity, truth_vals, pipe_vals, title_suffix, xlabel,
                 description, specific_plt_name) in v2v3_specs:
                rel_diff_data = auxfunc.get_reldiffarr_and_stats(
                    threshold_diff, truth_slity, truth_vals, pipe_vals, tested_quantity)
                # converting to degrees to compare with truth, pipeline is in arcsec
                if rel_diff_data[-2][0] > 0.0:
                    print("\nConverting pipeline results to degrees to compare with truth file")
                    rel_diff_data = auxfunc.get_reldiffarr_and_stats(
                        threshold_diff, truth_slity, truth_vals, pipe_vals / 3600., tested_quantity)
                _evaluate_quantity(rel_diff_data, threshold_diff, pipeslit, tested_quantity,
                                   total_test_result, log_msgs)
                plot_specs.append((title_suffix, xlabel, rel_diff_data, description, specific_plt_name))

        # PLOTS
        if show_figs or save_figs:
            main_title = filt + "   " + grat + "   SLIT=" + pipeslit + "\n"
            bins = 15  # binning for the histograms, if None the function will automatically calculate number
            plt_origin = None
            for title_suffix, xlabel, rel_diff_data, description, specific_plt_name in plot_specs:
                diff_img, notnan_diff, notnan_stats, _ = rel_diff_data
                if _stat_is_nan(notnan_stats[1]):
                    _print_and_log("Unable to create plot of " + description + ".", log_msgs)
                    continue
                info_img = [main_title + title_suffix + "\n", "x (pixels)", "y (pixels)"]
                info_hist = [xlabel, "N", bins, notnan_stats]
                plt_name = _plot_file_name(infile_name, basenameinfile_name, output_directory,
                                           pipeslit, det, specific_plt_name)
                auxfunc.plt_two_2Dimgandhist(diff_img,
                                             notnan_diff,
                                             info_img,
                                             info_hist,
                                             plt_name=plt_name,
                                             plt_origin=plt_origin,
                                             show_figs=show_figs,
                                             save_figs=save_figs)
        else:
            _print_and_log("NO plots were made because show_figs and save_figs were both set to False. \n",
                           log_msgs)

    # PASSED only if at least one test ran and every recorded test passed
    FINAL_TEST_RESULT = "PASSED" if total_test_result else "FAILED"
    for sl, testdir in total_test_result.items():
        for t, tr in testdir.items():
            if tr == "FAILED":
                FINAL_TEST_RESULT = "FAILED"
                _print_and_log("\n * The test of " + t + " for slit " + sl + " FAILED.", log_msgs)
            else:
                _print_and_log("\n * The test of " + t + " for slit " + sl + " PASSED.", log_msgs)

    if FINAL_TEST_RESULT == "PASSED":
        msg = "\n *** Final result for assign_wcs test will be reported as PASSED *** \n"
    else:
        msg = "\n *** Final result for assign_wcs test will be reported as FAILED *** \n"
    _print_and_log(msg, log_msgs)

    return FINAL_TEST_RESULT, log_msgs
예제 #11
0
def do_correction(input_model,
                  barshadow_model=None,
                  inverse=False,
                  source_type=None,
                  correction_pars=None):
    """Do the Bar Shadow Correction

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        science data model to be corrected

    barshadow_model : `~jwst.datamodels.BarshadowModel`
        bar shadow data model from reference file

    inverse : boolean
        Invert the math operations used to apply the bar shadow correction:
        the science, error and variance arrays are multiplied by the
        correction (and its square, for variances) instead of divided.

    source_type : str or None
        Force processing using the specified source type.

    correction_pars : `~jwst.datamodels.MultiSlitModel` or None
        Correction arrays to use instead of recalculation, with the same
        structure as the ``corrections`` value returned by this function.
        ``correction_pars.slits[i]`` is applied to the i-th input slit.

    Returns
    -------
    output_model, corrections : `~jwst.datamodels.MultiSlitModel`, `~jwst.datamodels.MultiSlitModel`
        Science data model with correction applied and barshadow extensions added,
        and a model of the correction arrays (one entry per slit).
    """

    # Input is a MultiSlitModel science data model.
    # A MultislitModel has a member ".slits" that behaves like
    # a list of Slits, each of which has several data arrays and
    # keyword metadata items associated.
    #
    # At this point we're going to have to assume that a Slit is composed
    # of a set of slit[].nshutters in a vertical line.
    #
    # Reference file information is a 1x1 ref file and a 1x3 ref file.
    # Both the 1x1 and 1x3 ref files are 1001 pixels high and go from
    # -1 to 1 in their Y value WCS.

    exp_type = input_model.meta.exposure.type
    log.debug('EXP_TYPE = %s', exp_type)

    # Create output as a copy of the input science data model, so the
    # in-place arithmetic below never touches the caller's model.
    output_model = input_model.copy()

    # Loop over all the slits in the input model
    corrections = datamodels.MultiSlitModel()
    for slit_idx, slitlet in enumerate(output_model.slits):
        # Lazy %s formatting: a missing (None) slitlet_id is logged instead of
        # raising TypeError as an eager '%d' % None would.
        log.info('Working on slitlet %s', slitlet.slitlet_id)

        if correction_pars:
            correction = correction_pars.slits[slit_idx]
        else:
            correction = _calc_correction(slitlet, barshadow_model,
                                          source_type)
        corrections.slits.append(correction)

        factor = correction.data
        # err is a standard deviation and scales with the factor; the
        # var_* arrays are variances and scale with the factor squared.
        factor_sq = factor**2
        has_var_flat = (slitlet.var_flat is not None
                        and np.size(slitlet.var_flat) > 0)

        if not inverse:
            # Apply the correction by dividing it into the science and
            # uncertainty arrays.
            slitlet.data /= factor
            slitlet.err /= factor
            slitlet.var_poisson /= factor_sq
            slitlet.var_rnoise /= factor_sq
            if has_var_flat:
                slitlet.var_flat /= factor_sq
        else:
            # Exact inverse of the branch above.  Uncertainties must be
            # inverted too, otherwise they no longer match the science data
            # (e.g. correct-then-invert would divide err twice).
            slitlet.data *= factor
            slitlet.err *= factor
            slitlet.var_poisson *= factor_sq
            slitlet.var_rnoise *= factor_sq
            if has_var_flat:
                slitlet.var_flat *= factor_sq

        slitlet.barshadow = correction.data

    return output_model, corrections
def flattest(step_input_filename,
             dflatref_path=None,
             sfile_path=None,
             fflat_path=None,
             msa_shutter_conf=None,
             writefile=False,
             show_figs=True,
             save_figs=False,
             plot_name=None,
             threshold_diff=1.0e-14,
             debug=False):
    """
    This function does the WCS comparison from the world coordinates calculated using the
    compute_world_coordinates.py script with the ESA files. The function calls that script.

    Args:
        step_input_filename: str, name of the output fits file from the 2d_extract step (with full path)
        dflatref_path: str, path of where the D-flat reference fits files
        sfile_path: str, path of where the S-flat reference fits files
        fflat_path: str, path of where the F-flat reference fits files
        msa_shutter_conf: str, full path and name of the MSA configuration fits file
        writefile: boolean, if True writes the fits files of the calculated flat and difference images
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots (the 3 plots can be saved or not independently with the function call)
        plot_name: string, desired name (if name is not given, the plot function will name the plot by
                    default)
        threshold_diff: float, threshold difference between pipeline output and ESA file
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - 1 plot, if told to save and/or show.
        - median_diff: Boolean, True if smaller or equal to 1e-14
        - log_msgs: list, all print statements are captured in this variable

    """

    log_msgs = []

    # start the timer
    flattest_start_time = time.time()

    # get info from the rate file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    lamp = fits.getval(step_input_filename, "LAMP", 0)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    msg = "rate_file  -->     Grating:" + grat + "   Filter:" + filt + "   Lamp:" + lamp
    print(msg)
    log_msgs.append(msg)

    # read in the on-the-fly flat image
    flatfile = step_input_filename.replace("flat_field.fits",
                                           "interpolatedflat.fits")

    # get the reference files
    # D-Flat
    dflat_ending = "f_01.03.fits"
    dfile = "_".join((dflatref_path, "nrs1", dflat_ending))
    if det == "NRS2":
        dfile = dfile.replace("nrs1", "nrs2")
    msg = "Using D-flat: " + dfile
    print(msg)
    log_msgs.append(msg)
    dfim = fits.getdata(dfile, "SCI")  #1)
    dfimdq = fits.getdata(dfile, "DQ")  #4)
    # need to flip/rotate the image into science orientation
    ns = np.shape(dfim)
    dfim = np.transpose(dfim, (
        0, 2,
        1))  # keep in mind that 0,1,2 = z,y,x in Python, whereas =x,y,z in IDL
    dfimdq = np.transpose(dfimdq)
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        dfim = dfim[..., ::-1, ::-1]
        dfimdq = dfimdq[..., ::-1, ::-1]
    naxis3 = fits.getval(dfile, "NAXIS3", "SCI")

    # get the wavelength values
    dfwave = np.array([])
    for i in range(naxis3):
        keyword = "_".join(("PFLAT", str(i + 1)))
        dfwave = np.append(dfwave, fits.getval(dfile, keyword, "SCI"))
    dfrqe = fits.getdata(dfile, 2)

    # S-flat
    tsp = exptype.split("_")
    mode = tsp[1]
    if filt == "F070LP":
        flat = "FLAT4"
    elif filt == "F100LP":
        flat = "FLAT1"
    elif filt == "F170LP":
        flat = "FLAT2"
    elif filt == "F290LP":
        flat = "FLAT3"
    elif filt == "CLEAR":
        flat = "FLAT5"
    else:
        msg = "No filter correspondence. Exiting the program."
        print(msg)
        log_msgs.append(msg)
        # This is the key argument for the assert pytest function
        result_msg = "Test skiped because there is no flat correspondance for the filter in the data: {}".format(
            filt)
        median_diff = "skip"
        return median_diff, result_msg, log_msgs

    sflat_ending = "f_01.01.fits"
    sfile = "_".join((sfile_path, grat, "OPAQUE", flat, "nrs1", sflat_ending))
    msg = "Using S-flat: " + sfile
    print(msg)
    log_msgs.append(msg)

    if det == "NRS2":
        sfile = sfile.replace("nrs1", "nrs2")
    sfim = fits.getdata(sfile, "SCI")  #1)
    sfimdq = fits.getdata(sfile, "DQ")  #3)

    # need to flip/rotate image into science orientation
    sfim = np.transpose(sfim, (0, 2, 1))
    sfimdq = np.transpose(sfimdq, (0, 2, 1))
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        sfim = sfim[..., ::-1, ::-1]
        sfimdq = sfimdq[..., ::-1, ::-1]

    # get the wavelength values for sflat cube
    sfimwave = np.array([])
    naxis3 = fits.getval(sfile, "NAXIS3", "SCI")
    for i in range(0, naxis3):
        if i + 1 < 10:
            keyword = "".join(("FLAT_0", str(i + 1)))
        else:
            keyword = "".join(("FLAT_", str(i + 1)))
        #print("S-flat -> using ", keyword)
        try:
            sfimwave = np.append(sfimwave, fits.getval(sfile, keyword, "SCI"))
        except:
            KeyError
    sfv = fits.getdata(sfile, 5)

    # F-Flat
    #print("F-flat -> using the following flats: ")
    fflat_ending = "_01.01.fits"
    ffile = fflat_path + "_" + filt + fflat_ending
    msg = "Using F-flat: " + ffile
    print(msg)
    log_msgs.append(msg)
    ffsq1 = fits.getdata(ffile, "SCI_Q1")  #1)
    naxis3 = fits.getval(ffile, "NAXIS3", "SCI_Q1")  #1)
    ffswaveq1 = np.array([])
    for i in range(0, naxis3):
        if i <= 9:
            suff = "".join(("0", str(i)))
        else:
            suff = str(i)
        t = ("FLAT", suff)
        keyword = "_".join(t)
        #print("1. F-flat -> ", keyword)
        ffswaveq1 = np.append(ffswaveq1, fits.getval(ffile, keyword, "SCI_Q1"))
    ffserrq1 = fits.getdata(ffile, "ERR_Q1")  #2)
    ffsdqq1 = fits.getdata(ffile, "DQ_Q1")  #3)
    ffvq1 = fits.getdata(ffile, "Q1")  #4)
    ffsq2 = fits.getdata(ffile, "SCI_Q2")
    ffswaveq2 = np.array([])
    for i in range(0, naxis3):
        if i <= 9:
            suff = "".join(("0", str(i)))
        else:
            suff = str(i)
        t = ("FLAT", suff)
        keyword = "_".join(t)
        #print("2. F-flat -> using ", keyword)
        ffswaveq2 = np.append(ffswaveq2, fits.getval(ffile, keyword, "SCI_Q2"))
    ffserrq2 = fits.getdata(ffile, "ERR_Q2")
    ffsdqq2 = fits.getdata(ffile, "DQ_Q2")
    ffvq2 = fits.getdata(ffile, "Q2")
    ffsq3 = fits.getdata(ffile, "SCI_Q3")
    ffswaveq3 = np.array([])
    for i in range(0, naxis3):
        if i <= 9:
            suff = "".join(("0", str(i)))
        else:
            suff = str(i)
        t = ("FLAT", suff)
        keyword = "_".join(t)
        #print("3. F-flat -> using ", keyword)
        ffswaveq3 = np.append(ffswaveq3, fits.getval(ffile, keyword, "SCI_Q3"))
    ffserrq3 = fits.getdata(ffile, "ERR_Q3")
    ffsdqq3 = fits.getdata(ffile, "DQ_Q3")
    ffvq3 = fits.getdata(ffile, "Q3")
    ffsq4 = fits.getdata(ffile, "SCI_Q4")
    ffswaveq4 = np.array([])
    for i in range(0, naxis3):
        if i <= 9:
            suff = "0" + str(i)
        else:
            suff = str(i)
        keyword = "FLAT_" + suff
        #print("4. F-flat -> using ", keyword)
        ffswaveq4 = np.append(ffswaveq4, fits.getval(ffile, keyword, "SCI_Q4"))
    ffserrq4 = fits.getdata(ffile, "ERR_Q4")
    ffsdqq4 = fits.getdata(ffile, "DQ_Q4")
    ffvq4 = fits.getdata(ffile, "Q4")

    # now go through each pixel in the test data

    if writefile:
        # create the fits list to hold the image of the correction values
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create the fits list to hold the image of the comparison values
        hdu0 = fits.PrimaryHDU()
        complfile = fits.HDUList()
        complfile.append(hdu0)

    # list to determine if pytest is passed or not
    total_test_result = []

    # get the datamodel from the assign_wcs output file
    extract2d_file = step_input_filename.replace("_flat_field.fits",
                                                 "_extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_file)

    # get all the science extensions in the flatfile
    sci_ext_list = auxfunc.get_sci_extensions(flatfile)

    # loop over the 2D subwindows and read in the WCS values
    for slit in model.slits:
        slit_id = slit.name
        msg = "\nWorking with slit: " + slit_id
        print(msg)
        log_msgs.append(msg)
        ext = sci_ext_list[
            slit_id]  # this is for getting the science extension in the pipeline calculated flat
        # make sure that the slitlet is open and projected on the detector, otherwise indicate so
        #if not slit_id in open_and_on_detector_slits_list:
        #    print("* This open slitlet was removed because it is not projected on the detector. Test skipped for this slitlet. \n")
        #    continue

        # get the wavelength
        y, x = np.mgrid[:slit.data.shape[0], :slit.data.shape[1]]
        ra, dec, wave = slit.meta.wcs(x, y)  # wave is in microns

        # get the subwindow origin
        px0 = slit.xstart - 1 + model.meta.subarray.xstart
        py0 = slit.ystart - 1 + model.meta.subarray.ystart
        msg = " Subwindow origin:   px0=" + repr(px0) + "   py0=" + repr(py0)
        print(msg)
        log_msgs.append(msg)
        n_p = np.shape(wave)
        nw = n_p[0] * n_p[1]
        nw1, nw2 = n_p[1], n_p[
            0]  # remember that x=nw1 and y=nw2 are reversed  in Python
        if debug:
            print("nw = ", nw)

        delf = np.zeros([nw2, nw1]) + 999.0
        flatcor = np.zeros([nw2, nw1]) + 999.0

        # get the slitlet info, needed for the F-Flat
        ext_shutter_info = "SHUTTER_INFO"  # this is extension 2 of the msa file, that has the shutter info
        slitlet_info = fits.getdata(msa_shutter_conf, ext_shutter_info)
        sltid = slitlet_info.field("SLITLET_ID")
        for j, s in enumerate(sltid):
            if s == int(slit_id):
                im = j
                # get the shutter with the source in it
                if slitlet_info.field("BACKGROUND")[im] == "N":
                    isrc = j
        # changes suggested by Phil Hodge
        quad = slit.quadrant  #slitlet_info.field("SHUTTER_QUADRANT")[isrc]
        row = slit.xcen  #slitlet_info.field("SHUTTER_ROW")[isrc]
        col = slit.ycen  #slitlet_info.field("SHUTTER_COLUMN")[isrc]
        slitlet_id = repr(row) + "_" + repr(col)
        msg = 'silt_id=' + repr(slit_id) + "   quad=" + repr(
            quad) + "   row=" + repr(row) + "   col=" + repr(
                col) + "   slitlet_id=" + repr(slitlet_id)
        print(msg)
        log_msgs.append(msg)

        # get the relevant F-flat reference data
        if quad == 1:
            ffsall = ffsq1
            ffsallwave = ffswaveq1
            ffsalldq = ffsdqq1
            ffv = ffvq1
        if quad == 2:
            ffsall = ffsq2
            ffsallwave = ffswaveq2
            ffsalldq = ffsdqq2
            ffv = ffvq2
        if quad == 3:
            ffsall = ffsq3
            ffsallwave = ffswaveq3
            ffsalldq = ffsdqq3
            ffv = ffvq3
        if quad == 4:
            ffsall = ffsq4
            ffsallwave = ffswaveq4
            ffsalldq = ffsdqq4
            ffv = ffvq4

        # loop through the pixels
        msg = "Now looping through the pixels, this will take a while ... "
        print(msg)
        log_msgs.append(msg)
        wave_shape = np.shape(wave)
        for j in range(nw1):  # in x
            for k in range(nw2):  # in y
                if np.isfinite(wave[k, j]):  # skip if wavelength is NaN
                    # get the pixel indeces
                    jwav = wave[k, j]
                    pind = [k + py0 - 1, j + px0 - 1]
                    if debug:
                        print('j, k, jwav, px0, py0 : ', j, k, jwav, px0, py0)
                        print('pind = ', pind)

                    # get the pixel bandwidth
                    if (j != 0) and (j < nw1 - 1):
                        if np.isfinite(wave[k, j + 1]) and np.isfinite(
                                wave[k, j - 1]):
                            delw = 0.5 * (wave[k, j + 1] - wave[k, j - 1])
                        if np.isfinite(wave[k, j + 1]) and not np.isfinite(
                                wave[k, j - 1]):
                            delw = wave[k, j + 1] - wave[k, j]
                        if not np.isfinite(wave[k, j + 1]) and np.isfinite(
                                wave[k, j - 1]):
                            delw = wave[k, j] - wave[k, j - 1]
                    if j == 0:
                        delw = wave[k, j + 1] - wave[k, j]
                    if j == nw - 1:
                        delw = wave[k, j] - wave[k, j - 1]

                    if debug:
                        print("wave[k, j+1], wave[k, j-1] : ",
                              np.isfinite(wave[k, j + 1]), wave[k, j + 1],
                              wave[k, j - 1])
                        print("delw = ", delw)

                    # integrate over dflat fast vector
                    dfrqe_wav = dfrqe.field("WAVELENGTH")
                    dfrqe_rqe = dfrqe.field("RQE")
                    iw = np.where((dfrqe_wav >= wave[k, j] - delw / 2.0)
                                  & (dfrqe_wav <= wave[k, j] + delw / 2.0))
                    if np.size(iw) == 0:
                        dff = 1.0
                    else:
                        int_tab = auxfunc.idl_tabulate(dfrqe_wav[iw[0]],
                                                       dfrqe_rqe[iw[0]])
                        first_dfrqe_wav, last_dfrqe_wav = dfrqe_wav[
                            iw[0]][0], dfrqe_wav[iw[0]][-1]
                        dff = int_tab / (last_dfrqe_wav - first_dfrqe_wav)

                    if debug:
                        #print("np.shape(dfrqe_wav) : ", np.shape(dfrqe_wav))
                        #print("np.shape(dfrqe_rqe) : ", np.shape(dfrqe_rqe))
                        #print("dfimdq[pind[0]][pind[1]] : ", dfimdq[pind[0]][pind[1]])
                        #print("np.shape(iw) =", np.shape(iw))
                        #print("np.shape(dfrqe_wav[iw[0]]) = ", np.shape(dfrqe_wav[iw[0]]))
                        #print("np.shape(dfrqe_rqe[iw[0]]) = ", np.shape(dfrqe_rqe[iw[0]]))
                        #print("int_tab=", int_tab)
                        print("dff = ", dff)

                    # interpolate over dflat cube
                    iloc = auxfunc.idl_valuelocate(dfwave, wave[k, j])[0]
                    if dfwave[iloc] > wave[k, j]:
                        iloc -= 1
                    ibr = [iloc]
                    if iloc != len(dfwave) - 1:
                        ibr.append(iloc + 1)
                    # get the values in the z-array at indeces ibr, and x=pind[1] and y=pind[0]
                    zz = dfim[:, pind[0], pind[1]][ibr]
                    # now determine the length of the array with only the finite numbers
                    zzwherenonan = np.where(np.isfinite(zz))
                    kk = np.size(zzwherenonan)
                    dfs = 1.0
                    if (wave[k, j] <= max(dfwave)) and (
                            wave[k, j] >= min(dfwave)) and (kk == 2):
                        dfs = np.interp(wave[k, j], dfwave[ibr],
                                        zz[zzwherenonan])
                    # check DQ flags
                    if dfimdq[pind[0], pind[1]] != 0:
                        dfs = 1.0

                    # integrate over S-flat fast vector
                    sfv_wav = sfv.field("WAVELENGTH")
                    sfv_dat = sfv.field("DATA")
                    iw = np.where((sfv_wav >= wave[k, j] - delw / 2.0)
                                  & (sfv_wav <= wave[k, j] + delw / 2.0))
                    sff = 1.0
                    if np.size(iw) > 2:
                        int_tab = auxfunc.idl_tabulate(sfv_wav[iw],
                                                       sfv_dat[iw])
                        first_sfv_wav, last_sfv_wav = sfv_wav[
                            iw[0]][0], sfv_wav[iw[0]][-1]
                        sff = int_tab / (last_sfv_wav - first_sfv_wav)

                    # interpolate s-flat cube
                    iloc = auxfunc.idl_valuelocate(sfimwave, wave[k, j])[0]
                    ibr = [iloc]
                    if iloc != len(sfimwave) - 1:
                        ibr.append(iloc + 1)
                    # get the values in the z-array at indeces ibr, and x=pind[1] and y=pind[0]
                    zz = sfim[:, pind[0], pind[1]][ibr]
                    # now determine the length of the array with only the finite numbers
                    zzwherenonan = np.where(np.isfinite(zz))
                    kk = np.size(zzwherenonan)
                    sfs = 1.0
                    if (wave[k, j] <= max(sfimwave)) and (
                            wave[k, j] >= min(sfimwave)) and (kk == 2):
                        sfs = np.interp(wave[k, j], sfimwave[ibr],
                                        zz[zzwherenonan])

                    # check DQ flags
                    kk = np.where(sfimdq[:, pind[0], pind[1]][ibr] == 0)
                    if np.size(kk) != 2:
                        sfs = 1.0

                    # integrate over f-flat fast vector
                    # reference file wavelength range is from 0.6 to 5.206 microns, so need to force
                    # solution to 1 for wavelengths outside that range
                    ffv_wav = ffv.field("WAVELENGTH")
                    ffv_dat = ffv.field("DATA")
                    fff = 1.0
                    if (wave[k, j] - delw / 2.0 >=
                            0.6) and (wave[k, j] + delw / 2.0 <= 5.206):
                        iw = np.where((ffv_wav >= wave[k, j] - delw / 2.0)
                                      & (ffv_wav <= wave[k, j] + delw / 2.0))
                        if np.size(iw) > 1:
                            int_tab = auxfunc.idl_tabulate(
                                ffv_wav[iw], ffv_dat[iw])
                            first_ffv_wav, last_ffv_wav = ffv_wav[
                                iw[0]][0], ffv_wav[iw[0]][-1]
                            fff = int_tab / (last_ffv_wav - first_ffv_wav)

                    # interpolate over f-flat cube
                    ffs = np.interp(wave[k, j], ffsallwave, ffsall[:, col - 1,
                                                                   row - 1])

                    flatcor[k, j] = dff * dfs * sff * sfs * fff * ffs

                    if (pind[1] - px0 + 1 == 9999) and (pind[0] - py0 + 1
                                                        == 9999):
                        if debug:
                            print("pind = ", pind)
                            print("wave[k, j] = ", wave[k, j])
                            print("dfs, dff = ", dfs, dff)
                            print("sfs, sff = ", sfs, sff)

                        msg = "Making the plot fot this slitlet..."
                        print(msg)
                        log_msgs.append(msg)
                        # make plot
                        font = {  #'family' : 'normal',
                            'weight': 'normal',
                            'size': 16
                        }
                        matplotlib.rc('font', **font)
                        fig = plt.figure(1, figsize=(12, 10))
                        plt.subplots_adjust(hspace=.4)
                        ax = plt.subplot(111)
                        xmin = wave[k, j] - 0.01
                        xmax = wave[k, j] + 0.01
                        plt.xlim(xmin, xmax)
                        plt.plot(dfwave,
                                 dfim[:, pind[0], pind[1]],
                                 linewidth=7,
                                 marker='D',
                                 color='k',
                                 label="dflat_im")
                        plt.plot(wave[k, j],
                                 dfs,
                                 linewidth=7,
                                 marker='D',
                                 color='r')
                        plt.plot(dfrqe_wav,
                                 dfrqe_rqe,
                                 linewidth=7,
                                 marker='D',
                                 c='k',
                                 label="dflat_vec")
                        plt.plot(wave[k, j],
                                 dff,
                                 linewidth=7,
                                 marker='D',
                                 color='r')
                        plt.plot(sfimwave,
                                 sfim[:, pind[0], pind[1]],
                                 linewidth=7,
                                 marker='D',
                                 color='k',
                                 label="sflat_im")
                        plt.plot(wave[k, j],
                                 sfs,
                                 linewidth=7,
                                 marker='D',
                                 color='r')
                        plt.plot(sfv_wav,
                                 sfv_dat,
                                 linewidth=7,
                                 marker='D',
                                 color='k',
                                 label="sflat_vec")
                        plt.plot(wave[k, j],
                                 sff,
                                 linewidth=7,
                                 marker='D',
                                 color='r')
                        # add legend
                        box = ax.get_position()
                        ax.set_position(
                            [box.x0, box.y0, box.width * 1.0, box.height])
                        ax.legend(loc='upper right', bbox_to_anchor=(1, 1))
                        plt.minorticks_on()
                        plt.tick_params(axis='both',
                                        which='both',
                                        bottom=True,
                                        top=True,
                                        right=True,
                                        direction='in',
                                        labelbottom=True)
                        plt.show()
                        msg = "Exiting the program. Unable to calculate statistics. Test set to be SKIPPED."
                        print(msg)
                        log_msgs.append(msg)
                        plt.close()
                        result_msg = "Unable to calculate statistics. Test set be SKIP."
                        median_diff = "skip"
                        return median_diff, result_msg, log_msgs

                    if debug:
                        print("dfs = ", dfs)
                        print("sff = ", sff)
                        print("sfs = ", sfs)
                        print("ffs = ", ffs)

                    # read the pipeline-calculated flat image
                    # there are four extensions in the flatfile: SCI, DQ, ERR, WAVELENGTH
                    pipeflat = fits.getdata(flatfile, ext)

                    try:
                        # Difference between pipeline and calculated values
                        delf[k, j] = pipeflat[k, j] - flatcor[k, j]

                        # Remove all pixels with values=1 (outside slit boundaries) for statistics
                        if pipeflat[k, j] == 1:
                            delf[k, j] = 999.0
                        if np.isnan(wave[k, j]):
                            flatcor[k,
                                    j] = 1.0  # no correction if no wavelength

                        if debug:
                            print("flatcor[k, j] = ", flatcor[k, j])
                            print("delf[k, j] = ", delf[k, j])
                    except:
                        IndexError

        nanind = np.isnan(delf)  # get all the nan indexes
        notnan = ~nanind  # get all the not-nan indexes
        delf = delf[notnan]  # get rid of NaNs
        delf_shape = np.shape(delf)
        test_result = "FAILED"
        if delf.size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN."
            msg2 = "   Test will be set to FAILED and NO plots will be made."
            print(msg1)
            print(msg2)
            log_msgs.append(msg1)
            log_msgs.append(msg2)
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            delfg = delf[np.where((delf != 999.0) & (delf < 0.1)
                                  & (delf > -0.1))]  # ignore outliers
            if delfg.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values."
                msg2 = "   Test will be set to FAILED and NO plots will be made."
                print(msg1)
                print(msg2)
                log_msgs.append(msg1)
                log_msgs.append(msg2)
            else:
                stats_and_strings = auxfunc.print_stats(delfg,
                                                        "Flat Difference",
                                                        float(threshold_diff),
                                                        abs=True)
                stats, stats_print_strings = stats_and_strings
                delfg_mean, delfg_median, delfg_std = stats

                # This is the key argument for the assert pytest function
                median_diff = False
                if abs(delfg_median) <= float(threshold_diff):
                    median_diff = True
                if median_diff:
                    test_result = "PASSED"
                else:
                    test_result = "FAILED"

                if save_figs or show_figs:
                    # make histogram
                    msg = "Making histogram plot for this slitlet..."
                    print(msg)
                    log_msgs.append(msg)
                    # set the plot variables
                    main_title = filt + "   " + grat + "   SLIT=" + slit_id + "\n"
                    bins = None  # binning for the histograms, if None the function will select them automatically
                    #             lolim_x, uplim_x, lolim_y, uplim_y
                    plt_origin = None

                    # Residuals img and histogram
                    title = main_title + "Residuals"
                    info_img = [title, "x (pixels)", "y (pixels)"]
                    xlabel, ylabel = "flat$_{pipe}$ - flat$_{calc}$", "N"
                    info_hist = [xlabel, ylabel, bins, stats]
                    if delfg[1] is np.nan:
                        msg = "Unable to create plot of relative wavelength difference."
                        print(msg)
                        log_msgs.append(msg)
                    else:
                        file_path = step_input_filename.replace(
                            os.path.basename(step_input_filename), "")
                        file_basename = os.path.basename(
                            step_input_filename.replace(".fits", ""))
                        t = (file_basename,
                             "MOS_flattest_" + slitlet_id + "_histogram.pdf")
                        plt_name = "_".join(t)
                        plt_name = os.path.join(file_path, plt_name)
                        difference_img = (pipeflat - flatcor)  #/flatcor
                        in_slit = np.logical_and(
                            difference_img < 900.0,
                            difference_img > -900.0)  # ignore out of slitlet
                        difference_img[
                            ~in_slit] = np.nan  # Set values outside the slit to NaN
                        nanind = np.isnan(
                            difference_img)  # get all the nan indexes
                        difference_img[
                            nanind] = np.nan  # set all nan indexes to have a value of nan
                        vminmax = [
                            -5 * delfg_std, 5 * delfg_std
                        ]  # set the range of values to be shown in the image, will affect color scale
                        auxfunc.plt_two_2Dimgandhist(difference_img,
                                                     delfg,
                                                     info_img,
                                                     info_hist,
                                                     plt_name=plt_name,
                                                     vminmax=vminmax,
                                                     plt_origin=plt_origin,
                                                     show_figs=show_figs,
                                                     save_figs=save_figs)

                elif not save_figs and not show_figs:
                    msg = "Not making plots because both show_figs and save_figs were set to False."
                    print(msg)
                    log_msgs.append(msg)
                elif not save_figs:
                    msg = "Not saving plots because save_figs was set to False."
                    print(msg)
                    log_msgs.append(msg)

        msg = " *** Result of the test: " + test_result + "\n"
        print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

        # create fits file to hold the calculated flat for each slit
        if writefile:
            # this is the file to hold the image of the correction values
            outfile_ext = fits.ImageHDU(flatcor.reshape(wave_shape),
                                        name=slitlet_id)
            outfile.append(outfile_ext)

            # this is the file to hold the image of the comparison values
            complfile_ext = fits.ImageHDU(delf.reshape(delf_shape),
                                          name=slitlet_id)
            complfile.append(complfile_ext)

            # the file is not yet written, indicate that this slit was appended to list to be written
            msg = "Extension corresponding to slitlet " + slitlet_id + " appended to list to be written into calculated and comparison fits files."
            print(msg)
            log_msgs.append(msg)

    if writefile:
        outfile_name = step_input_filename.replace("flat_field.fits",
                                                   det + "_flat_calc.fits")
        complfile_name = step_input_filename.replace("flat_field.fits",
                                                     det + "_flat_comp.fits")

        # this is the file to hold the image of pipeline-calculated difference values
        outfile.writeto(outfile_name, overwrite=True)

        # this is the file to hold the image of pipeline-calculated difference values
        complfile.writeto(complfile_name, overwrite=True)

        msg = "\nFits file with calculated flat values of each slit saved as: "
        print(msg)
        print(outfile_name)
        log_msgs.append(msg)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline flat - calculated flat) saved as: "
        print(msg)
        print(complfile_name)
        log_msgs.append(msg)
        log_msgs.append(complfile_name)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    FINAL_TEST_RESULT = True
    for t in total_test_result:
        if t == "FAILED":
            FINAL_TEST_RESULT = False
            break
    if FINAL_TEST_RESULT:
        msg = "\n *** Final result for flat_field test will be reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slitlets PASSED flat_field test."
    else:
        msg = "\n *** Final result for flat_field test will be reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slitlets FAILED flat_field test."

    # end the timer
    flattest_end_time = time.time() - flattest_start_time
    if flattest_end_time > 60.0:
        flattest_end_time = flattest_end_time / 60.0  # in minutes
        flattest_tot_time = "* Script flattest_mos.py took ", repr(
            flattest_end_time) + " minutes to finish."
        if flattest_end_time > 60.0:
            flattest_end_time = flattest_end_time / 60.  # in hours
            flattest_tot_time = "* Script flattest_mos.py took ", repr(
                flattest_end_time) + " hours to finish."
    else:
        flattest_tot_time = "* Script flattest_mos.py took ", repr(
            flattest_end_time) + " seconds to finish."
    print(flattest_tot_time)
    log_msgs.append(flattest_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
def flattest(step_input_filename, dflatref_path=None, sfile_path=None, fflat_path=None, writefile=True,
             show_figs=True, save_figs=False, plot_name=None, threshold_diff=1.0e-7, debug=False):
    """
    Compare the pipeline flat field with an independently calculated flat for NIRSpec fixed-slit (FS) data.

    For every FS slit found in the "_extract_2d.fits" file that accompanies step_input_filename, the flat
    is rebuilt pixel by pixel from the D-, S- and F-flat reference files. It is then subtracted from the
    pipeline's interpolated flat ("interpolatedflat.fits"). The median of the difference is compared
    against threshold_diff.

    Args:
        step_input_filename: str, name of the output fits file from the flat_field step (with full path).
            Its name is expected to end in "flat_field.fits"; the matching "_extract_2d.fits" and
            "interpolatedflat.fits" file names are derived from it by string replacement.
        dflatref_path: str, path prefix of the D-flat reference file; "_nrs1_f_01.03.fits" (or the nrs2
            equivalent) is appended to it. Required despite the None default.
        sfile_path: str, path prefix of the S-flat reference files; must contain "FS", otherwise the test
            is skipped. Required despite the None default.
        fflat_path: str, path prefix of the F-flat reference files; must contain "FS", otherwise the test
            is skipped. Required despite the None default.
        writefile: boolean, if True writes the fits files of the calculated flat and difference images
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, whether to save the plots or not
        plot_name: string, currently not used; plot file names are derived from step_input_filename
        threshold_diff: float, threshold for the absolute value of the median difference
            (pipeline flat - calculated flat)
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - FINAL_TEST_RESULT: True if every tested slit PASSED; False if any slit FAILED or no slit could
          be tested; the string "skip" if the reference files could not be determined.
        - result_msg: str, one-line summary of the outcome
        - log_msgs: list, all print statements are captured in this variable
        Plots are shown and/or saved as a side effect when requested.
    """

    log_msgs = []

    # start the timer
    flattest_start_time = time.time()

    # get info from the rate file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    msg = "flat_field_file  -->     Grating:" + grat + "   Filter:" + filt + "   EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)
    is_bots = exptype == "NRS_BRIGHTOBJ"

    # read in the on-the-fly flat image
    flatfile = step_input_filename.replace("flat_field.fits", "interpolatedflat.fits")

    # get the reference files
    # D-Flat
    dflat_ending = "f_01.03.fits"
    dfile = "_".join((dflatref_path, "nrs1", dflat_ending))
    if det == "NRS2":
        dfile = dfile.replace("nrs1", "nrs2")
    msg = "Using D-flat: " + dfile
    print(msg)
    log_msgs.append(msg)
    dfim = fits.getdata(dfile, "SCI")
    dfimdq = fits.getdata(dfile, "DQ")
    # need to flip/rotate the image into science orientation
    # (keep in mind that axes 0,1,2 = z,y,x in Python, whereas they are x,y,z in IDL)
    dfim = np.transpose(dfim, (0, 2, 1))
    dfimdq = np.transpose(dfimdq)
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        dfim = dfim[..., ::-1, ::-1]
        dfimdq = dfimdq[..., ::-1, ::-1]
    naxis3 = fits.getval(dfile, "NAXIS3", 1)
    if debug:
        print('np.shape(dfim) =', np.shape(dfim))
        print('np.shape(dfimdq) =', np.shape(dfimdq))

    # get the wavelength of each plane of the D-flat cube (keywords PFLAT_1 ... PFLAT_<NAXIS3>)
    dfwave = np.array([fits.getval(dfile, "PFLAT_" + str(i + 1), 1) for i in range(naxis3)], dtype=float)
    dfwave_min, dfwave_max = min(dfwave), max(dfwave)
    # D-flat fast vector; these columns do not depend on the pixel, so read them once
    dfrqe = fits.getdata(dfile, 2)
    dfrqe_wav = dfrqe.field("WAVELENGTH")
    dfrqe_rqe = dfrqe.field("RQE")

    # S-flat
    mode = "FS"
    filter_to_sflat = {"F070LP": "FLAT4", "F100LP": "FLAT1", "F170LP": "FLAT2",
                       "F290LP": "FLAT3", "CLEAR": "FLAT5"}
    flat = filter_to_sflat.get(filt)
    if flat is None:
        msg = "No filter correspondence. Exiting the program."
        print(msg)
        log_msgs.append(msg)
        result_msg = "Test skiped because there is no flat correspondance for the filter in the data: {}".format(filt)
        median_diff = "skip"
        return median_diff, result_msg, log_msgs

    sflat_ending = "f_01.01.fits"
    if mode not in sfile_path:
        msg = "Wrong path in for mode S-flat. This script handles mode " + mode + " only."
        print(msg)
        log_msgs.append(msg)
        # This is the key argument for the assert pytest function
        result_msg = "Wrong path in for mode S-flat. Test skiped because mode is not FS."
        median_diff = "skip"
        return median_diff, result_msg, log_msgs
    sfile = "_".join((sfile_path, grat, "OPAQUE", flat, "nrs1", sflat_ending))

    if debug:
        print("grat = ", grat)
        print("flat = ", flat)
        print("sfile used = ", sfile)

    if det == "NRS2":
        sfile = sfile.replace("nrs1", "nrs2")
    msg = "Using S-flat: " + sfile
    print(msg)
    log_msgs.append(msg)
    sfim = fits.getdata(sfile, "SCI")
    sfimdq = fits.getdata(sfile, "DQ")

    # need to flip/rotate image into science orientation
    sfim = np.transpose(sfim)
    sfimdq = np.transpose(sfimdq)
    if det == "NRS2":
        # rotate science data by 180 degrees for NRS2
        sfim = sfim[..., ::-1, ::-1]
        sfimdq = sfimdq[..., ::-1, ::-1]
    if debug:
        print("np.shape(sfim) = ", np.shape(sfim))
        print("np.shape(sfimdq) = ", np.shape(sfimdq))
        with fits.open(sfile) as sf:
            sf.info()

    # S-flat fast vectors: one extension per slit. A missing extension only disables that slit
    # instead of crashing later with an undefined name.
    sfv_extnames = {"S200A1": "SLIT_A_200_1", "S200A2": "SLIT_A_200_2",
                    "S400A1": "SLIT_A_400", "S1600A1": "SLIT_A_1600"}
    if det == "NRS2":
        sfv_extnames["S200B1"] = "SLIT_B_200"
    sfv_by_slit = {}
    for sltname, extname in sfv_extnames.items():
        try:
            sfv_by_slit[sltname] = fits.getdata(sfile, extname)
        except KeyError:
            msg = " * S-Flat-Field file does not have extension " + extname + " for slit " + sltname
            print(msg)
            log_msgs.append(msg)

    # F-Flat
    fflat_ending = "01.01.fits"
    if mode not in fflat_path:
        msg = "Wrong path in for mode F-flat. This script handles mode " + mode + " only."
        print(msg)
        log_msgs.append(msg)
        # This is the key argument for the assert pytest function
        median_diff = "skip"
        return median_diff, msg, log_msgs
    ffile = "_".join((fflat_path, filt, fflat_ending))

    msg = "Using F-flat: " + ffile
    print(msg)
    log_msgs.append(msg)
    ffv = fits.getdata(ffile, 1)
    # F-flat fast vector columns, read once instead of once per pixel
    ffv_wav = ffv.field("WAVELENGTH")
    ffv_dat = ffv.field("DATA")

    # now go through each pixel in the test data

    # get the datamodel from the extract_2d output file
    extract2d_wcs_file = step_input_filename.replace("_flat_field.fits", "_extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_wcs_file)

    if writefile:
        # create the fits list to hold the calculated flat values for each slit
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create the fits list to hold the image of pipeline-calculated difference values
        hdu0 = fits.PrimaryHDU()
        complfile = fits.HDUList()
        complfile.append(hdu0)

    # list to determine if pytest is passed or not
    total_test_result = []

    # loop over the slits
    sltname_list = ["S200A1", "S200A2", "S400A1", "S1600A1"]
    msg = "Now looping through the slits. This may take a while... "
    print(msg)
    log_msgs.append(msg)
    if det == "NRS2":
        sltname_list.append("S200B1")

    # but check if data is BOTS
    if is_bots:
        sltname_list = ["S1600A1"]

    # get all the science extensions
    sci_ext_list = auxfunc.get_sci_extensions(flatfile)

    # do the loop over the slits
    for slit_id in sltname_list:
        if is_bots:
            slit = model
        else:
            slit = next((s for s in model.slits if s.name == slit_id), None)
            if slit is None:
                continue

        msg = "\nWorking with slit: " + slit_id
        print(msg)
        log_msgs.append(msg)

        # select the appropriate S-flat fast vector
        sfv = sfv_by_slit.get(slit_id)
        if sfv is None:
            msg = " * No S-flat fast vector available for slit " + slit_id + ", skipping it."
            print(msg)
            log_msgs.append(msg)
            continue
        sfv_wav = sfv.field("WAVELENGTH")
        sfv_dat = sfv.field("DATA")

        try:
            ext = sci_ext_list[slit_id]  # science extension of this slit in the pipeline calculated flat
        except KeyError:
            # the extension does not exist in the file, look for the next slit name
            msg = " * Slit " + slit_id + " not found in " + flatfile + ", skipping it."
            print(msg)
            log_msgs.append(msg)
            continue

        # get the wavelength of every pixel inside the slit bounding box
        x, y = wcstools.grid_from_bounding_box(slit.meta.wcs.bounding_box, step=(1, 1), center=True)
        _, _, wave = slit.meta.wcs(x, y)  # wave is in microns

        # get the subwindow origin
        px0 = slit.xstart - 1 + model.meta.subarray.xstart
        py0 = slit.ystart - 1 + model.meta.subarray.ystart
        msg = " Subwindow origin:   px0=" + repr(px0) + "   py0=" + repr(py0)
        print(msg)
        log_msgs.append(msg)
        n_p = np.shape(wave)
        nw = n_p[0] * n_p[1]
        nw1, nw2 = n_p[1], n_p[0]  # remember that x=nw1 and y=nw2 are reversed in Python
        if debug:
            print(" nw1, nw2, nw = ", nw1, nw2, nw)

        # 999.0 marks pixels that were not computed (e.g. NaN wavelength)
        delf = np.zeros([nw2, nw1]) + 999.0
        flatcor = np.zeros([nw2, nw1]) + 999.0

        # read the pipeline-calculated flat image
        # there are four extensions in the flatfile: SCI, DQ, ERR, WAVELENGTH
        pipeflat = fits.getdata(flatfile, ext)

        # make sure the two arrays are the same shape
        if np.shape(flatcor) != np.shape(pipeflat):
            msg1 = 'WARNING -> Something went wrong, arrays are not the same shape:'
            msg2 = 'np.shape(flatcor) = ' + repr(np.shape(flatcor)) + '   np.shape(pipeflat) = ' + repr(
                np.shape(pipeflat))
            msg3 = 'Mathematical operations will fail. Exiting the loop here and setting test as FAILED.'
            print(msg1)
            print(msg2)
            print(msg3)
            log_msgs.append(msg1)
            log_msgs.append(msg2)
            log_msgs.append(msg3)
            # the result must be set before it is reported
            test_result = "FAILED"
            msg = " *** Result of the test: " + test_result + "\n"
            print(msg)
            log_msgs.append(msg)
            total_test_result.append(test_result)
            continue

        # loop through the wavelengths
        msg = " Looping through the wavelengths... "
        print(msg)
        log_msgs.append(msg)
        for j in range(nw1):  # in x
            for k in range(nw2):  # in y
                if not np.isfinite(wave[k, j]):  # skip if wavelength is NaN
                    continue
                # get the full-frame pixel indices for D- and S-flat image components
                pind = [k + py0 - 1, j + px0 - 1]

                # get the pixel bandwidth from the neighbouring wavelengths along x
                if j == 0:
                    delw = wave[k, j + 1] - wave[k, j]
                elif j == nw1 - 1:
                    delw = wave[k, j] - wave[k, j - 1]
                else:
                    right_ok = np.isfinite(wave[k, j + 1])
                    left_ok = np.isfinite(wave[k, j - 1])
                    if right_ok and left_ok:
                        delw = 0.5 * (wave[k, j + 1] - wave[k, j - 1])
                    elif right_ok:
                        delw = wave[k, j + 1] - wave[k, j]
                    elif left_ok:
                        delw = wave[k, j] - wave[k, j - 1]
                    # NOTE(review): if both neighbours are NaN, delw keeps the value of the
                    # previously processed pixel (unchanged from the original behavior)

                # integrate over D-flat fast vector
                iw = np.where((dfrqe_wav >= wave[k, j] - delw / 2.) & (dfrqe_wav <= wave[k, j] + delw / 2.))
                int_tab = auxfunc.idl_tabulate(dfrqe_wav[iw], dfrqe_rqe[iw])
                first_dfrqe_wav, last_dfrqe_wav = dfrqe_wav[iw[0]][0], dfrqe_wav[iw[0]][-1]
                dff = int_tab / (last_dfrqe_wav - first_dfrqe_wav)

                if debug:
                    print("np.shape(dfrqe_wav) : ", np.shape(dfrqe_wav))
                    print("np.shape(dfrqe_rqe) : ", np.shape(dfrqe_rqe))
                    print("dfimdq[pind[0],[pind[1]] : ", dfimdq[pind[0], pind[1]])
                    print("np.shape(iw) =", np.shape(iw))
                    print("np.shape(dfrqe_wav) = ", np.shape(dfrqe_wav[iw]))
                    print("np.shape(dfrqe_rqe) = ", np.shape(dfrqe_rqe[iw]))
                    print("int_tab=", int_tab)
                    print("np.shape(dfim) = ", np.shape(dfim))
                    print("dff = ", dff)

                # interpolate over D-flat cube
                iloc = auxfunc.idl_valuelocate(dfwave, wave[k, j])[0]
                if dfwave[iloc] > wave[k, j]:
                    iloc -= 1
                ibr = [iloc]
                if iloc != len(dfwave) - 1:
                    ibr.append(iloc + 1)
                # get the values in the z-array at indices ibr, and x=pind[1] and y=pind[0]
                zz = dfim[:, pind[0], pind[1]][ibr]
                # now determine the length of the array with only the finite numbers
                zzwherenonan = np.where(np.isfinite(zz))
                kk = np.size(zzwherenonan)
                dfs = 1.0
                if (dfwave_min <= wave[k, j] <= dfwave_max) and (kk == 2):
                    dfs = np.interp(wave[k, j], dfwave[ibr], zz[zzwherenonan])

                # check DQ flags
                if dfimdq[pind[0], pind[1]] != 0:
                    dfs = 1.0

                if debug:
                    print("wave[k, j] = ", wave[k, j])
                    print("iloc = ", iloc)
                    print("ibr = ", ibr)
                    print("np.interp(wave[k, j], dfwave[ibr], zz[zzwherenonan]) = ",
                          np.interp(wave[k, j], dfwave[ibr], zz[zzwherenonan]))
                    print("dfs = ", dfs)

                # integrate over S-flat fast vector
                iw = np.where((sfv_wav >= wave[k, j] - delw / 2.0) & (sfv_wav <= wave[k, j] + delw / 2.0))
                sff = 1.0
                if np.size(iw) > 2:
                    int_tab = auxfunc.idl_tabulate(sfv_wav[iw], sfv_dat[iw])
                    first_sfv_wav, last_sfv_wav = sfv_wav[iw[0]][0], sfv_wav[iw[0]][-1]
                    sff = int_tab / (last_sfv_wav - first_sfv_wav)
                # get s-flat pixel-dependent correction
                sfs = 1.0
                if sfimdq[pind[0], pind[1]] == 0:
                    sfs = sfim[pind[0], pind[1]]

                if debug:
                    print("np.shape(iw) =", np.shape(iw))
                    print("np.shape(sfv_wav) = ", np.shape(sfv_wav))
                    print("np.shape(sfv_dat) = ", np.shape(sfv_dat))
                    print("int_tab = ", int_tab)
                    print("sff = ", sff)
                    print("sfs = ", sfs)

                # integrate over F-flat fast vector
                # reference file blue cutoff is 1 micron, so need to force solution for shorter wavs
                fff = 1.0
                if wave[k, j] - delw / 2.0 >= 1.0:
                    iw = np.where((ffv_wav >= wave[k, j] - delw / 2.0) & (ffv_wav <= wave[k, j] + delw / 2.0))
                    if np.size(iw) > 1:
                        int_tab = auxfunc.idl_tabulate(ffv_wav[iw], ffv_dat[iw])
                        first_ffv_wav, last_ffv_wav = ffv_wav[iw[0]][0], ffv_wav[iw[0]][-1]
                        fff = int_tab / (last_ffv_wav - first_ffv_wav)

                flatcor[k, j] = dff * dfs * sff * sfs * fff

                if debug:
                    print("np.shape(iw) =", np.shape(iw))
                    print("np.shape(ffv_wav) = ", np.shape(ffv_wav))
                    print("np.shape(ffv_dat) = ", np.shape(ffv_dat))
                    print("fff = ", fff)
                    print("flatcor[k, j] = ", flatcor[k, j])
                    print("dff, dfs, sff, sfs, fff:", dff, dfs, sff, sfs, fff)

                # Difference between pipeline and calculated values
                # (indices are in range: pipeflat was checked above to have the same shape as flatcor)
                delf[k, j] = pipeflat[k, j] - flatcor[k, j]

                if debug:
                    print("delf[k, j] = ", delf[k, j])

                # Remove all pixels with values=1 (outside slit boundaries) for statistics
                if pipeflat[k, j] == 1:
                    delf[k, j] = 999.0

                if debug:
                    print("flatcor[k, j] = ", flatcor[k, j])
                    print("delf[k, j] = ", delf[k, j])

        if debug:
            no_999 = delf[np.where(delf != 999.0)]
            print("np.shape(no_999) = ", np.shape(no_999))
            alldelf = delf.flatten()
            print("median of the whole array: ", np.median(alldelf))
            print("median, stdev in delf: ", np.median(no_999), np.std(no_999))
            neg_vals = no_999[np.where(no_999 < 0.0)]
            print("neg_vals = ", np.shape(neg_vals))
            print("np.shape(delf) = ", np.shape(delf))

        # drop NaN differences for the statistics; delf itself keeps its 2D shape so that it
        # can be saved as an image below
        delf_valid = delf[~np.isnan(delf)]
        delfg = np.array([])
        if delf_valid.size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN. Test will be set to FAILED."
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
            delfg_mean, delfg_median, delfg_std = np.nan, np.nan, np.nan
            stats = [delfg_mean, delfg_median, delfg_std]
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            # ignore outliers and pixels that were not computed
            delfg = delf_valid[(delf_valid != 999.0) & (delf_valid < 0.1) & (delf_valid > -0.1)]
            if debug:
                print("np.shape(delfg) = ", np.shape(delfg))
            if delfg.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values. Test will be set to FAILED."
                print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
                delfg_mean, delfg_median, delfg_std = np.nan, np.nan, np.nan
                stats = [delfg_mean, delfg_median, delfg_std]
            else:
                stats, stats_print_strings = auxfunc.print_stats(delfg, "Flat Difference", float(threshold_diff), abs=True)
                delfg_mean, delfg_median, delfg_std = stats
                log_msgs.extend(stats_print_strings)

                # This is the key argument for the assert pytest function
                test_result = "PASSED" if abs(delfg_median) <= float(threshold_diff) else "FAILED"

        msg = " *** Result of the test: " + test_result + "\n"
        print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

        # make histogram
        if show_figs or save_figs:

            # set plot variables
            main_title = filt + "   " + grat + "   SLIT=" + slit_id + "\n"
            bins = None  # binning for the histograms, if None the function will select them automatically
            plt_origin = None

            # Residuals img and histogram
            title = main_title + "Residuals"
            info_img = [title, "x (pixels)", "y (pixels)"]
            xlabel, ylabel = "flat$_{pipe}$ - flat$_{calc}$", "N"
            info_hist = [xlabel, ylabel, bins, stats]
            if delfg.size == 0:
                # nothing to histogram and no finite standard deviation for the color scale
                msg = "Unable to create plot of relative wavelength difference."
                print(msg)
                log_msgs.append(msg)
            else:
                file_path = step_input_filename.replace(os.path.basename(step_input_filename), "")
                file_basename = os.path.basename(step_input_filename.replace(".fits", ""))
                plt_name = "_".join((file_basename, "FS_flattest_" + slit_id + "_histogram.pdf"))
                plt_name = os.path.join(file_path, plt_name)
                difference_img = (pipeflat - flatcor)
                # ignore points out of the slit (uncomputed pixels hold 999-based values)
                in_slit = np.logical_and(difference_img < 900.0, difference_img > -900.0)
                difference_img[~in_slit] = np.nan  # Set values outside the slit to NaN
                # set the range of values to be shown in the image, will affect color scale
                vminmax = [-5 * delfg_std, 5 * delfg_std]
                auxfunc.plt_two_2Dimgandhist(difference_img, delfg, info_img, info_hist, plt_name=plt_name,
                                             vminmax=vminmax,
                                             plt_origin=plt_origin, show_figs=show_figs, save_figs=save_figs)
        else:
            msg = "Not making plots because both show_figs and save_figs were set to False."
            print(msg)
            log_msgs.append(msg)

        # create fits file to hold the calculated flat for each slit
        if writefile:
            msg = "Saving the fits files with the calculated flat for each slit..."
            print(msg)
            log_msgs.append(msg)

            # this is the file to hold the image of the calculated flat values
            outfile.append(fits.ImageHDU(flatcor, name=slit_id))

            # this is the file to hold the image of pipeline-calculated difference values
            complfile.append(fits.ImageHDU(delf, name=slit_id))

            # the file is not yet written, indicate that this slit was appended to list to be written
            msg = ("Extension corresponding to slit " + slit_id +
                   " appended to list to be written into calculated and comparison fits files.")
            print(msg)
            log_msgs.append(msg)

    if writefile:
        outfile_name = step_input_filename.replace("flat_field.fits", det + "_flat_calc.fits")
        complfile_name = step_input_filename.replace("flat_field.fits", det + "_flat_comp.fits")

        # this is the file to hold the image of the calculated flat values
        outfile.writeto(outfile_name, overwrite=True)

        # this is the file to hold the image of pipeline-calculated difference values
        complfile.writeto(complfile_name, overwrite=True)

        msg = "\nFits file with calculated flat values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline flat - calculated flat) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(complfile_name)
        log_msgs.append(complfile_name)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    # (having tested no slit at all is also reported as FAILED)
    FINAL_TEST_RESULT = bool(total_test_result) and "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        msg = "\n *** Final result for flat_field test will be reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED flat_field test."
    else:
        msg = "\n *** Final result for flat_field test will be reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED flat_field test."

    # end the timer; build a single string (not a tuple) with the most readable unit
    flattest_end_time = time.time() - flattest_start_time
    if flattest_end_time > 60.0:
        flattest_end_time = flattest_end_time / 60.0  # in minutes
        time_unit = " minutes to finish."
        if flattest_end_time > 60.0:
            flattest_end_time = flattest_end_time / 60.  # in hours
            time_unit = " hours to finish."
    else:
        time_unit = " seconds to finish."
    flattest_tot_time = "* Script flattest_fs.py took " + repr(flattest_end_time) + time_unit
    print(flattest_tot_time)
    log_msgs.append(flattest_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
예제 #14
0
def contam_corr(input_model, waverange, photom, max_cores):
    """
    The main WFSS contamination correction function

    A full-frame simulated grism image is built by dispersing every source
    of the segmentation map, for every non-zero spectral order listed in the
    wavelength-range reference file. For each slit, the simulated spectrum
    of that slit's own source (for the slit's spectral order) is removed from
    the full simulation. What remains is the contamination estimate. It is
    cut out to the slit footprint, stored in the contamination model and
    subtracted from the slit data of a copy of the input.

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Input data model containing 2D spectral cutouts
    waverange : `~jwst.datamodels.WavelengthrangeModel`
        Wavelength range reference file model
    photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel`
        Photom (flux cal) reference file model
    max_cores : string
        Number of cores to use for multiprocessing. If set to 'none'
        (the default), then no multiprocessing will be done. The other
        allowable values are 'quarter', 'half', and 'all', which indicate
        the fraction of cores to use for multi-proc. The total number of
        cores includes the SMT cores (Hyper Threading for Intel).
        Any other value falls back to a single core.

    Returns
    -------
    output_model : `~jwst.datamodels.MultiSlitModel`
        A copy of the input_model that has been corrected
    simul_model : `~jwst.datamodels.ImageModel`
        Full-frame simulated image of the grism exposure
    contam_model : `~jwst.datamodels.MultiSlitModel`
        Contamination estimate images for each source slit

    """
    # Work out how many processes the dispersion may use
    if max_cores == 'none':
        ncpus = 1
    else:
        num_cores = multiprocessing.cpu_count()
        # Fractions never drop below one core; unknown values mean one core
        cores_for = {
            'quarter': num_cores // 4 or 1,
            'half': num_cores // 2 or 1,
            'all': num_cores,
        }
        ncpus = cores_for.get(max_cores, 1)
        log.debug(f"Found {num_cores} cores; using {ncpus}")

    # The corrected result is written into a copy of the input
    output_model = input_model.copy()

    # Segmentation map that defines the source footprints
    seg_model = datamodels.open(input_model.meta.segmentation_map)

    # Direct image that the segmentation map was derived from
    image_names = [input_model.meta.direct_image]
    log.debug(f"Direct image names={image_names}")

    # All slits share the grism WCS; take it from the first one
    grism_wcs = input_model.slits[0].meta.wcs

    # Non-zero orders listed in the Wavelengthrange reference file
    order_table = np.asarray(waverange.order)
    spec_orders = order_table[order_table != 0]
    log.debug(f"Spectral orders defined = {spec_orders}")

    filter_kwd = input_model.meta.instrument.filter
    pupil_kwd = input_model.meta.instrument.pupil

    # NIRCam WFSS has filters in the FILTER wheel and gratings in the PUPIL
    # wheel; NIRISS is the reverse, so its filter name comes from PUPIL.
    is_niriss = input_model.meta.instrument.name == 'NIRISS'
    filter_name = pupil_kwd if is_niriss else filter_kwd

    # Per-order wavelength limits and sensitivity curves
    wmin, wmax, sens_waves, sens_response = {}, {}, {}, {}
    for order in spec_orders:
        bounds = waverange.get_wfss_wavelength_range(filter_name, [order])[order]
        wmin[order] = bounds[0]
        wmax[order] = bounds[1]
        sens_waves[order], sens_response[order] = get_photom_data(
            photom, filter_kwd, pupil_kwd, order)
    log.debug(f"wmin={wmin}, wmax={wmax}")

    # Sum the full-frame simulations of all orders
    simul_all = None
    obs = Observation(image_names, seg_model, grism_wcs, filter_name,
                      boundaries=[0, 2047, 0, 2047], max_cpu=ncpus)
    for order in spec_orders:
        log.info(f"Creating full simulated grism image for order {order}")
        obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order],
                         sens_response[order])
        if simul_all is None:
            simul_all = obs.simulated_image
        else:
            simul_all += obs.simulated_image

    simul_model = datamodels.ImageModel(data=simul_all)
    simul_model.update(input_model, only="PRIMARY")

    log.info("Creating contamination image for each individual source")
    contam_model = datamodels.MultiSlitModel()
    contam_model.update(input_model)
    contam_slits = []
    for slit in output_model.slits:
        slit_order = slit.meta.wcsinfo.spectral_order
        # Index of this slit's source within the Observation
        chunk = np.where(obs.IDs == slit.source_id)[0][0]

        # Disperse only this source onto a blank frame
        obs.simulated_image = np.zeros(obs.dims)
        obs.disperse_chunk(chunk, slit_order, wmin[slit_order],
                           wmax[slit_order], sens_waves[slit_order],
                           sens_response[slit_order])

        # Everything except this source is contamination
        contam = simul_all - obs.simulated_image

        # xstart/ystart are 1-indexed, hence the -1 offset
        col0 = slit.xstart - 1
        row0 = slit.ystart - 1
        cutout = contam[row0:row0 + slit.ysize, col0:col0 + slit.xsize]

        contam_slit = datamodels.SlitModel(data=cutout)
        copy_slit_info(slit, contam_slit)
        contam_slits.append(contam_slit)

        slit.data -= cutout

    contam_model.slits.extend(contam_slits)

    output_model.meta.cal_step.wfss_contam = 'COMPLETE'

    return output_model, simul_model, contam_model
def _format_elapsed_time(script_name, elapsed_seconds):
    """Return a human-readable run-time message for *script_name*.

    The elapsed time is reported in seconds. It is converted to minutes when
    it exceeds 60 seconds, and further to hours when it exceeds 60 minutes.

    Args:
        script_name: str, script name to mention in the message
        elapsed_seconds: float, elapsed wall-clock time in seconds
    Returns:
        str, e.g. "* Script FS_UNI.py took 2.0 minutes to finish."
    """
    elapsed = elapsed_seconds
    unit = "seconds"
    if elapsed > 60.0:
        elapsed = elapsed / 60.0  # in minutes
        unit = "minutes"
        if elapsed > 60.0:
            elapsed = elapsed / 60.0  # in hours
            unit = "hours"
    return "* Script " + script_name + " took " + repr(elapsed) + " " + unit + " to finish."


def pathtest(step_input_filename,
             reffile,
             comparison_filename,
             writefile=True,
             show_figs=False,
             save_figs=True,
             threshold_diff=1.0e-7,
             debug=False):
    """
    This function calculates the difference between the pipeline and
    calculated pathloss values.
    Args:
        step_input_filename: str, full path name of sourcetype output fits file
        reffile: str, path to the pathloss FS reference fits file
        comparison_filename: str, path to pipeline-generated pathloss fits file
        writefile: boolean, if True writes the fits files of
                   calculated pathloss and difference images
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, whether to save the plots or not
        threshold_diff: float, threshold difference between
                        pipeline output and comparison file
        debug: boolean, if true print statements will show on-screen
    Returns:
        - plots, if told to save and/or show them.
        - final test result: boolean, True only if at least one slit was
          tested and every tested slit PASSED. It is the string 'skip'
          if the input or comparison file does not exist.
        - result_msg: str, summary message of the test outcome
        - log_msgs: list, all print statements are captured in this variable
    """

    log_msgs = []

    # start the timer
    pathtest_start_time = time.time()

    # Check that both files exist BEFORE reading anything from them, so a
    # missing file yields a 'skip' result instead of an I/O exception.
    print('Checking files exist & obtaining datamodels, takes a few mins...')
    if os.path.isfile(comparison_filename):
        if debug:
            print('Comparison file does exist.')
    else:
        result_msg = 'Comparison file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    if os.path.isfile(step_input_filename):
        if debug:
            print('Input file does exist.')
    else:
        result_msg = 'Input file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get info from the input file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    is_bots = exptype == "NRS_BRIGHTOBJ"

    msg = "pathloss file:  Grating:" + grat + " Filter:" + filt + " EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)

    is_point_source = False

    # get the datamodel from the extract_2d output file (it carries the WCS)
    extract2d_wcs_file = step_input_filename.replace("_srctype.fits",
                                                     "_extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_wcs_file)

    if writefile:
        # create the fits list to hold the calculated pathloss values for each slit
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create fits list to hold pipeline-calculated difference values
        hdu0 = fits.PrimaryHDU()
        compfile = fits.HDUList()
        compfile.append(hdu0)

    # list to determine if pytest is passed or not
    total_test_result = []

    # get the comparison data model
    pathloss_pipe = datamodels.open(comparison_filename)
    # For the moment, the pipeline is using the wrong reference file for slit
    # 400A1, so read a file that was re-processed with the right reference
    # file, if one is available. It stays None otherwise.
    pathloss_pipe_400a1 = None
    pathloss_400a1 = step_input_filename.replace("srctype.fits",
                                                 "pathloss_400A1.fits")
    if os.path.isfile(pathloss_400a1):
        pathloss_pipe_400a1 = datamodels.open(pathloss_400a1)
    if debug:
        print('got comparison datamodel!')

    # get the input data model
    pl = datamodels.open(step_input_filename)
    if debug:
        print('got input datamodel!')

    msg = "Now looping through the slits. This may take a while... "
    print(msg)
    log_msgs.append(msg)

    sltname_list = ["S200A1", "S200A2", "S400A1", "S1600A1"]
    if det == "NRS2":
        sltname_list.append("S200B1")

    # but check if data is BOTS
    if is_bots:
        sltname_list = ["S1600A1"]

    # get all the science extensions
    ps_uni_ext_list = get_ps_uni_extensions(reffile, is_point_source)

    for slit_val, (slit, pipe_slit) in enumerate(
            zip(pl.slits, pathloss_pipe.slits), start=1):
        slit_id = slit.name

        # Replace the input slit with the matching extract_2d slit (needed for
        # its WCS); skip slits that have no match.
        continue_pathloss_test = False
        if is_bots:
            slit = model
            continue_pathloss_test = True
        else:
            for slit_in_MultiSlitModel in model.slits:
                if slit_in_MultiSlitModel.name == slit_id:
                    slit = slit_in_MultiSlitModel
                    continue_pathloss_test = True
                    break

        if not continue_pathloss_test:
            continue

        try:
            if is_point_source is True:
                ext = ps_uni_ext_list[0][slit_id]
                print("Retrieved point source extension")
            elif is_point_source is False:
                ext = ps_uni_ext_list[1][slit_id]
                print("Retrieved extended source extension for {}".format(
                    slit_val))
        except KeyError:
            # gets index associated with slit if issue above
            ext = sltname_list.index(slit_id)
            print("Unable to retrieve extension.")

        wcs_obj = slit.meta.wcs

        # get the wavelength
        x, y = wcstools.grid_from_bounding_box(wcs_obj.bounding_box,
                                               step=(1, 1),
                                               center=True)
        _, _, wave = wcs_obj(x, y)
        wave_sci = wave * 10**(-6)  # microns --> meters

        # adjustments for S400A1
        if slit_id == "S400A1":
            if is_point_source:
                ext = 1
            else:
                ext = 3
                print(
                    "Got uniform source extension frome extra reference file")
            reffile2use = "jwst-nirspec-a400.plrf.fits"
        else:
            reffile2use = reffile

        print("Using reference file {}".format(reffile2use))
        plcor_ref_ext = fits.getdata(reffile2use, ext)

        # Build the reference wavelength grid from the WCS of extension 1; the
        # context manager makes sure the reference file is closed again.
        with fits.open(reffile2use) as hdul:
            plcor_ref = hdul[1].data
            w = wcs.WCS(hdul[1].header)
            w1, y1, x1 = np.mgrid[:plcor_ref.shape[0], :plcor_ref.
                                  shape[1], :plcor_ref.shape[2]]
        _, _, wave_ref = w.all_pix2world(x1, y1, w1, 0)

        previous_sci = slit.data
        # Default to the pipeline product for this slit; for S400A1 prefer
        # the re-processed product when it exists and contains that slit.
        comp_sci = pipe_slit.data
        pipe_correction = pipe_slit.pathloss
        if slit_id == 'S400A1' and pathloss_pipe_400a1 is not None:
            for pipe_slit_400a1 in pathloss_pipe_400a1.slits:
                if pipe_slit_400a1.name == "S400A1":
                    comp_sci = pipe_slit_400a1.data
                    pipe_correction = pipe_slit_400a1.pathloss
                    break
        if pipe_correction is None or np.size(pipe_correction) == 0:
            print(
                "Pipeline pathloss correction in datamodel is empty. Skipping testing this slit."
            )
            continue

        # set up generals for all the plots
        font = {'weight': 'normal', 'size': 7}
        matplotlib.rc('font', **font)

        corr_vals = np.interp(wave_sci, wave_ref[:, 0, 0], plcor_ref_ext)
        corrected_array = previous_sci / corr_vals

        # Plots:
        step_input_filepath = step_input_filename.replace(".fits", "")
        # my correction values
        fig = plt.figure()
        plt.subplot(321)
        norm = ImageNormalize(corr_vals)
        plt.imshow(corr_vals,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Calculated Correction')
        plt.colorbar()
        # pipe correction
        plt.subplot(322)
        norm = ImageNormalize(pipe_correction)
        plt.imshow(pipe_correction,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Pathloss Correction Comparison')
        plt.colorbar()
        # residuals (pipe correction - my correction)
        corr_residuals = pipe_correction - corr_vals
        plt.subplot(323)
        norm = ImageNormalize(corr_residuals)
        plt.imshow(corr_residuals,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Correction residuals')
        plt.colorbar()
        # pipe science data before
        plt.subplot(324)
        norm = ImageNormalize(previous_sci)
        plt.imshow(previous_sci,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Normalized pipeline science data before pathloss')
        plt.colorbar()
        # pipe science data after
        plt.subplot(325)
        norm = ImageNormalize(comp_sci)
        plt.imshow(comp_sci,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Normalized pipeline science data after pathloss')
        plt.colorbar()
        # my science data after pathloss
        plt.subplot(326)
        norm = ImageNormalize(corrected_array)
        plt.imshow(corrected_array,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title('My science data after pathloss')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.colorbar()
        fig.suptitle(
            "FS UNI Pathloss Correction Testing for {}".format(slit_id))

        # add space between the subplots
        fig.subplots_adjust(wspace=0.9)

        # Show and/or save figures
        if show_figs:
            plt.show()
        if save_figs:
            plt_name = step_input_filepath + "_Pathloss_test_slit_" + str(
                slit_id) + "_FS_extended.png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        elif not save_figs and not show_figs:
            msg = "Not making plots because both show_figs and save_figs were set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        elif not save_figs:
            msg = "Not saving plots because save_figs was set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        plt.clf()

        # create fits file to hold the calculated pathloss for each slit
        if writefile:
            msg = "Saving the fits files with the calculated pathloss for each slit..."
            print(msg)
            log_msgs.append(msg)

            # this is the file to hold the image of the calculated pathloss values
            outfile_ext = fits.ImageHDU(corr_vals, name=slit_id)
            outfile.append(outfile_ext)

            # this is the file to hold the image of pipeline-calculated difference values
            compfile_ext = fits.ImageHDU(corr_residuals, name=slit_id)
            compfile.append(compfile_ext)

        # Histogram
        ax = plt.subplot(212)
        plt.hist(corr_residuals[~np.isnan(corr_residuals)],
                 bins=100,
                 range=(-0.00000013, 0.00000013))
        plt.title('Residuals Histogram')
        plt.xlabel("Correction Value")
        plt.ylabel("Number of Occurences")
        notnan = ~np.isnan(corr_residuals)  # get all the not-nan indexes
        arr_mean = np.mean(corr_residuals[notnan])
        arr_median = np.median(corr_residuals[notnan])
        arr_stddev = np.std(corr_residuals[notnan])
        plt.axvline(arr_mean, label="mean = %0.3e" % (arr_mean), color="g")
        plt.axvline(arr_median,
                    label="median = %0.3e" % (arr_median),
                    linestyle="-.",
                    color="b")
        str_arr_stddev = "stddev = {:0.3e}".format(arr_stddev)
        ax.text(0.73, 0.67, str_arr_stddev, transform=ax.transAxes, fontsize=7)
        plt.legend()
        plt.minorticks_on()

        # Show and/or save figures
        if save_figs:
            plt_name = step_input_filename.replace(
                ".fits", "") + "_Pathlosstest_slitlet_" + slit_id + ".png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        elif not save_figs and not show_figs:
            msg = "Not making plots because both show_figs and save_figs were set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        elif not save_figs:
            msg = "Not saving plots because save_figs was set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)

        plt.close()

        if corr_residuals[~np.isnan(corr_residuals)].size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN. " \
                   "Test will be set to FAILED."
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            corr_residuals = corr_residuals[
                np.where((corr_residuals != 999.0) & (corr_residuals < 0.1)
                         & (corr_residuals > -0.1))]  # ignore outliers
            if corr_residuals.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values. Test " \
                       "will be set to FAILED."
                if debug:
                    print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
            else:
                stats_and_strings = auxfunc.print_stats(corr_residuals,
                                                        "Difference",
                                                        float(threshold_diff),
                                                        absolute=True)
                stats, stats_print_strings = stats_and_strings
                corr_residuals_mean, corr_residuals_median, corr_residuals_std = stats
                for msg in stats_print_strings:
                    log_msgs.append(msg)

                # This is the key argument for the assert pytest function
                median_diff = abs(corr_residuals_median) <= float(threshold_diff)
                test_result = "PASSED" if median_diff else "FAILED"

        msg = " *** Result of the test: " + test_result + "\n"
        if debug:
            print(msg)
        log_msgs.append(msg)
        # Record every outcome (including the all-NaN / all-outlier failures)
        # so that it is reflected in the final result.
        total_test_result.append(test_result)

    if writefile:
        # NOTE: the "calcuated" spelling is kept as-is; downstream tooling may
        # rely on this file name.
        outfile_name = step_input_filename.replace("srctype",
                                                   det + "_calcuated_pathloss")
        compfile_name = step_input_filename.replace(
            "srctype", det + "_comparison_pathloss")

        # create the fits list to hold the calculated pathloss values for each slit
        outfile.writeto(outfile_name, overwrite=True)

        # this is the file to hold the image of pipeline-calculated difference values
        compfile.writeto(compfile_name, overwrite=True)

        msg = "\nFits file with calculated pathloss values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline pathloss - calculated pathloss) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(compfile_name)
        log_msgs.append(compfile_name)

    # Release the data models opened above
    for opened_model in (pl, pathloss_pipe, model, pathloss_pipe_400a1):
        if opened_model is not None:
            opened_model.close()

    # PASSED only if at least one slit was tested and none of them FAILED
    final_test_result = bool(total_test_result) and "FAILED" not in total_test_result

    if final_test_result:
        msg = "\n *** Final result for path_loss test will be reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED path_loss test."
    else:
        msg = "\n *** Final result for path_loss test will be reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED path_loss test."

    # end the timer
    pathloss_tot_time = _format_elapsed_time(
        "FS_UNI.py", time.time() - pathtest_start_time)
    print(pathloss_tot_time)
    log_msgs.append(pathloss_tot_time)

    return final_test_result, result_msg, log_msgs