Exemple #1
0
def test_nirspec_nrs1_wcs(_bigdata):
    """
    Regression test of creating a WCS object and doing pixel to sky transformation.

    Runs ``AssignWcsStep`` on an NRS1 exposure from the ``nrs1-ifu`` test set,
    evaluates the WCS of slice 0 on a pixel grid spanning its bounding box,
    and compares (RA, Dec, lambda) with the WCS stored in the reference file.

    Parameters
    ----------
    _bigdata : str
        Root directory of the regression-test data.
    """
    output_file_base, output_file = add_suffix('nrs1_ifu_wcs_output.fits', 'assignwcsstep')

    # Remove stale output from a previous run; a missing file is fine.
    try:
        os.remove(output_file)
    except OSError:
        pass

    data_dir = os.path.join(_bigdata, 'nirspec', 'test_wcs', 'nrs1-ifu')
    input_file = os.path.join(data_dir, 'jw00011001001_01120_00001_NRS1_rate.fits')
    ref_file = os.path.join(data_dir, 'jw00011001001_01120_00001_NRS1_rate_assign_wcs.fits')

    AssignWcsStep.call(input_file, output_file=output_file_base, name='assignwcsstep')

    with ImageModel(output_file) as im, ImageModel(ref_file) as imref:
        w = nirspec.nrs_wcs_set_input(im, 0)
        bbox = w.bounding_box
        # bbox[0] bounds x and bbox[1] bounds y; mgrid yields (y, x).
        y, x = np.mgrid[bbox[1][0]:bbox[1][1], bbox[0][0]:bbox[0][1]]
        ra, dec, lam = w(x, y)
        # The reference WCS must come from the truth file; building it from
        # the step output would make the comparison vacuous.
        wref = nirspec.nrs_wcs_set_input(imref, 0)
        raref, decref, lamref = wref(x, y)

    # equal_nan is used here as many of the entries are nan.
    # The domain is defined, but only a few entries in it are valid
    # because it is a curved narrow slit.
    utils.assert_allclose(ra, raref, equal_nan=True)
    utils.assert_allclose(dec, decref, equal_nan=True)
    utils.assert_allclose(lam, lamref, equal_nan=True)
    def test_nirspec_nrs1_wcs(self):
        """
        Regression test of creating a WCS object and doing pixel to sky transformation.

        Runs ``AssignWcsStep`` on the NRS1 ramp-fit file and, for each of the
        slits S200A1, S200A2, S400A1 and S1600A1, compares (RA, Dec, lambda)
        evaluated on the slit's bounding-box grid with the reference file.
        """
        input_file = self.get_data(*self.test_dir,
                                   'jw00023001001_01101_00001_NRS1_ramp_fit.fits')
        ref_file = self.get_data(*self.ref_loc,
                                 'jw00023001001_01101_00001_NRS1_ramp_fit_assign_wcs.fits')

        result = AssignWcsStep.call(input_file, save_results=True, suffix='assign_wcs')
        # Only the saved file name is needed from the in-memory result.
        output_file = result.meta.filename
        result.close()

        # Context managers guarantee both files are closed, even on failure.
        with ImageModel(output_file) as im, ImageModel(ref_file) as imref:
            for slit in ['S200A1', 'S200A2', 'S400A1', 'S1600A1']:
                w = nirspec.nrs_wcs_set_input(im, slit)
                grid = grid_from_bounding_box(w.bounding_box)
                ra, dec, lam = w(*grid)
                wref = nirspec.nrs_wcs_set_input(imref, slit)
                raref, decref, lamref = wref(*grid)

                # Many grid points fall outside the slit and evaluate to NaN.
                assert_allclose(ra, raref, equal_nan=True)
                assert_allclose(dec, decref, equal_nan=True)
                assert_allclose(lam, lamref, equal_nan=True)
Exemple #3
0
    def test_nirspec_nrs1_wcs(self):
        """
        Regression test of creating a WCS object and doing pixel to sky transformation.

        Runs ``AssignWcsStep`` on the NRS1 ramp-fit file and, for each of the
        slits S200A1, S200A2, S400A1 and S1600A1, compares (RA, Dec, lambda)
        evaluated on the slit's bounding-box grid with the reference file.
        """
        input_file = self.get_data(*self.test_dir,
                                   'jw00023001001_01101_00001_NRS1_ramp_fit.fits')
        ref_file = self.get_data(*self.ref_loc,
                                 'jw00023001001_01101_00001_NRS1_ramp_fit_assign_wcs.fits')

        result = AssignWcsStep.call(input_file, save_results=True, suffix='assign_wcs')
        # Only the saved file name is needed from the in-memory result.
        output_file = result.meta.filename
        result.close()

        # Context managers guarantee both files are closed, even on failure.
        with ImageModel(output_file) as im, ImageModel(ref_file) as imref:
            for slit in ['S200A1', 'S200A2', 'S400A1', 'S1600A1']:
                w = nirspec.nrs_wcs_set_input(im, slit)
                grid = grid_from_bounding_box(w.bounding_box)
                ra, dec, lam = w(*grid)
                wref = nirspec.nrs_wcs_set_input(imref, slit)
                raref, decref, lamref = wref(*grid)

                # Many grid points fall outside the slit and evaluate to NaN.
                assert_allclose(ra, raref, equal_nan=True)
                assert_allclose(dec, decref, equal_nan=True)
                assert_allclose(lam, lamref, equal_nan=True)
def test_nirspec_ifu_wcs(_bigdata, test_id, input_file, truth_file):
    """
    Regression test of creating a WCS object and doing pixel to sky transformation.

    Runs ``AssignWcsStep`` on ``input_file`` from the ``nrs1-ifu`` test set and
    compares (RA, Dec, lambda) of slice 0, evaluated on its bounding-box grid,
    with the same slice of ``truth_file``.

    Parameters
    ----------
    _bigdata : str
        Root directory of the regression-test data.
    test_id : object
        Unused by the body; presumably a parametrization label.
    input_file, truth_file : str
        File names relative to ``<_bigdata>/nirspec/test_wcs/nrs1-ifu``.
    """
    del test_id

    data_dir = os.path.join(_bigdata, 'nirspec', 'test_wcs', 'nrs1-ifu')
    input_file = os.path.join(data_dir, input_file)
    truth_file = os.path.join(data_dir, truth_file)

    result = AssignWcsStep.call(input_file,
                                save_results=True,
                                suffix='assign_wcs')
    output_file = result.meta.filename
    result.close()

    # Context managers guarantee both files are closed, even on failure.
    with ImageModel(output_file) as im, ImageModel(truth_file) as imref:
        w = nirspec.nrs_wcs_set_input(im, 0)
        grid = grid_from_bounding_box(w.bounding_box)
        ra, dec, lam = w(*grid)
        wref = nirspec.nrs_wcs_set_input(imref, 0)
        raref, decref, lamref = wref(*grid)

    # equal_nan is used here as many of the entries are nan.
    # The domain is defined, but only a few entries in it are valid
    # because it is a curved narrow slit.
    assert_allclose(ra, raref, equal_nan=True)
    assert_allclose(dec, decref, equal_nan=True)
    assert_allclose(lam, lamref, equal_nan=True)
Exemple #5
0
def test_nirspec_ifu_wcs(envopt, _jail, test_id, input_file, truth_file):
    """
    Regression test of creating a WCS object and doing pixel to sky transformation.

    Fetches ``input_file`` and its ``truth_file`` from the ``nrs1-ifu`` test
    set, runs ``AssignWcsStep``, and compares (RA, Dec, lambda) of slice 0,
    evaluated on its bounding-box grid, with the truth file.

    Parameters
    ----------
    envopt : object
        Passed through to ``get_bigdata`` to select the data source.
    _jail : object
        Fixture requested for its side effect only.
    test_id : object
        Unused by the body; presumably a parametrization label.
    input_file, truth_file : str
        File names within the ``nrs1-ifu`` test set (truth under ``truth/``).
    """
    del test_id

    input_file = get_bigdata('jwst-pipeline', envopt,
                             'nirspec', 'test_wcs', 'nrs1-ifu', input_file)
    truth_file = get_bigdata('jwst-pipeline', envopt,
                             'nirspec', 'test_wcs', 'nrs1-ifu', 'truth', truth_file)

    result = AssignWcsStep.call(input_file, save_results=True, suffix='assign_wcs')
    output_file = result.meta.filename
    result.close()

    # Context managers guarantee both files are closed, even on failure.
    with ImageModel(output_file) as im, ImageModel(truth_file) as imref:
        w = nirspec.nrs_wcs_set_input(im, 0)
        grid = grid_from_bounding_box(w.bounding_box)
        ra, dec, lam = w(*grid)
        wref = nirspec.nrs_wcs_set_input(imref, 0)
        raref, decref, lamref = wref(*grid)

    # equal_nan is used here as many of the entries are nan.
    # The domain is defined, but only a few entries in it are valid
    # because it is a curved narrow slit.
    assert_allclose(ra, raref, equal_nan=True)
    assert_allclose(dec, decref, equal_nan=True)
    assert_allclose(lam, lamref, equal_nan=True)
Exemple #6
0
def test_nirspec_wcs(_jail, rtdata, test_id, input_file, truth_file):
    """
    Test of the AssignWcs step on 4 different NIRSpec exposures:
    1) IFU NRS1 exposure,
    2) IFU NRS1 exposure with FILTER=OPAQUE,
    3) IFU NRS2 exposure, and
    4) FS NRS1 exposure with 4 slits.

    For ``NRS_FIXEDSLIT`` exposures each of the four slits is compared with
    the truth file; for every other exposure type only slit/slice 0 is.
    """
    # Get the input and truth files
    rtdata.get_data('nirspec/test_wcs/' + input_file)
    rtdata.get_truth('truth/test_nirspec_wcs/' + truth_file)

    # Run the AssignWcs step; keep only what is needed from the result.
    result = AssignWcsStep.call(input_file, save_results=True, suffix='assign_wcs')
    output_file = result.meta.filename
    exp_type = result.meta.exposure.type
    result.close()

    if exp_type == 'NRS_FIXEDSLIT':
        slit_ids = ['S200A1', 'S200A2', 'S400A1', 'S1600A1']
    else:
        slit_ids = [0]

    # Context managers guarantee both files are closed, even on failure.
    with ImageModel(output_file) as im, ImageModel(truth_file) as im_ref:
        for slit_id in slit_ids:
            # Create WCS objects for each image
            wcs = nirspec.nrs_wcs_set_input(im, slit_id)
            wcs_ref = nirspec.nrs_wcs_set_input(im_ref, slit_id)

            # Compute RA, Dec, lambda values for each image array
            grid = grid_from_bounding_box(wcs.bounding_box)
            ra, dec, lam = wcs(*grid)
            ra_ref, dec_ref, lam_ref = wcs_ref(*grid)

            # Compare the sky coordinates.
            # equal_nan is used, because many of the entries are NaN,
            # due to the bounding_box being rectilinear while the
            # defined spectral traces are curved.
            assert_allclose(ra, ra_ref, equal_nan=True)
            assert_allclose(dec, dec_ref, equal_nan=True)
            assert_allclose(lam, lam_ref, equal_nan=True)
def test_nirspec_ifu_against_esa(wcs_ifu_grating):
    """
    Test Nirspec IFU mode using CV3 reference files.

    Pixels whose ESA slit-y offset lies strictly inside (-0.5, 0.5) are
    mapped through the ESA WCS into the pipeline's 0-based detector frame.
    The wavelengths produced there by the pipeline's ``sca -> msa_frame``
    transform for slice 0 must match the ESA ``LAMBDA1`` values to 1e-13.
    """
    # Pipeline WCS for slice 0 of a G140M/OPAQUE IFU exposure.
    im, _ = wcs_ifu_grating("G140M", "OPAQUE")
    w0 = nirspec.nrs_wcs_set_input(im, 0)
    sca2world = w0.get_transform('sca', 'msa_frame')

    trace_file = get_file_path(
        'Trace_IFU_Slice_00_SMOS-MOD-G1M-17-5344175105_30192_JLAB88.fits')
    # All use of the HDU data stays inside the context so the file is closed
    # even when the assertion fails.
    with fits.open(trace_file) as ref:
        # Test NRS1 (the "1" extensions of the ESA trace file).
        pyw = astwcs.WCS(ref['SLITY1'].header)
        slit1 = ref['SLITY1'].data  # y offset on the slit
        lam = ref['LAMBDA1'].data

        # Filter out locations outside the slit.
        cond = np.logical_and(slit1 < .5, slit1 > -.5)
        y, x = cond.nonzero()  # 0-based

        x, y = pyw.wcs_pix2world(x, y, 0)
        # The pipeline accepts 0-based coordinates.
        x -= 1
        y -= 1
        _, _, lp = sca2world(x, y)

        # Scale the pipeline wavelength by 1e-6 to the units of LAMBDA1
        # (presumably microns -> meters).
        lp *= 10**-6
        assert_allclose(lp, lam[cond], atol=1e-13)
Exemple #8
0
def xytov2v3l(x, y, file, nslice=30):
    """
    Map detector (x, y) positions to (V2, V3, lambda) and a slice index.

    The ``detector -> v2v3`` transform of every slice ``0 .. nslice-1`` is
    evaluated at all input positions. The per-position result is the
    NaN-ignoring median over slices, so a position for which only one slice
    gives finite values gets that slice's values.

    Parameters
    ----------
    x, y : sequence or ndarray
        Detector pixel coordinates of equal length.
    file : str
        File readable by ``datamodels.ImageModel`` with an assigned WCS.
    nslice : int, optional
        Number of slices to evaluate (default 30).

    Returns
    -------
    v2, v3, lam, sl : ndarray
        Median V2, V3, wavelength and slice index per position.
        ``v2``/``v3``/``lam`` are NaN where that quantity is NaN for every
        slice; ``sl`` is NaN where no slice gives a finite V2.
    """
    npos = len(x)
    # Rows are input positions, columns are slices.
    v2all = np.zeros([npos, nslice])
    v3all = np.zeros([npos, nslice])
    lamall = np.zeros([npos, nslice])
    slall = np.zeros([npos, nslice])

    with datamodels.ImageModel(file) as im:
        for ii in range(nslice):
            xform = nirspec.nrs_wcs_set_input(im, ii).get_transform('detector', 'v2v3')
            v2all[:, ii], v3all[:, ii], lamall[:, ii] = xform(x, y)
            slall[:, ii] = ii

    # A slice only counts for a position where its V2 is finite; blank the
    # slice index elsewhere so the median below ignores that slice.
    slall[~np.isfinite(v2all)] = np.nan

    v2 = np.nanmedian(v2all, axis=1)
    v3 = np.nanmedian(v3all, axis=1)
    lam = np.nanmedian(lamall, axis=1)
    sl = np.nanmedian(slall, axis=1)

    return v2, v3, lam, sl
Exemple #9
0
def test_shutter_size_on_sky():
    """
    Check the on-sky footprint of one MOS shutter.

    The corners of the unit slit (slit-frame coordinates +/-0.5) are projected
    to the sky at 2e-6. The asserted bounds are 0.193-0.194 arcsec across
    and 0.45-0.46 arcsec along the shutter.
    """
    model = datamodels.ImageModel(create_nirspec_mos_file())
    model.meta.instrument.msa_metadata_file = get_file_path('msa_configuration.fits')
    model.meta.instrument.msa_metadata_id = 12

    refs = create_reference_files(model)
    full_wcs = wcs.WCS(nirspec.create_pipeline(model, refs, slit_y_range=(-.5, .5)))
    model.meta.wcs = full_wcs

    first_slit = full_wcs.get_transform('slit_frame', 'msa_frame').slits[0]
    slit_wcs = nirspec.nrs_wcs_set_input(model, first_slit.name)
    to_world = slit_wcs.get_transform('slit_frame', 'world')

    # Closed polygon through the slit corners (last point repeats the first).
    corners_x = [-.5, -.5, .5, .5, -.5]
    corners_y = [-.5, .5, .5, -.5, -.5]
    ra, dec, _ = to_world(corners_x, corners_y, [2e-6] * 5)

    corners = coords.SkyCoord(ra * u.deg, dec * u.deg)
    width = corners[0].separation(corners[3]).to(u.arcsec).value
    height = corners[0].separation(corners[1]).to(u.arcsec).value

    assert 0.193 < width < 0.194
    assert 0.45 < height < 0.46
    def test_nirspec_ifu_masterbg_user(self):
        """
        Regression test of master background subtraction for NRS IFU when a
        user 1-D spectrum is provided.

        After ``MasterBackgroundStep`` subtracts the user 1-D background from a
        science exposure with a 2-D background added, each of the 30 IFU
        slices is compared against the background-free science exposure.
        The saved output is then compared against a truth file.
        """
        # input file has 2-D background image added to it
        input_file = self.get_data(*self.test_dir, 'prism_sci_bkg_cal.fits')

        # user-provided 1-D background was created from the 2-D background image
        user_background = self.get_data(*self.test_dir, 'prism_bkg_x1d.fits')

        result = MasterBackgroundStep.call(input_file,
                                           user_background=user_background,
                                           save_results=True)

        # Compare the science data with no background to the
        # background-subtracted output of the MasterBackground step.
        input_sci_cal_file = self.get_data(*self.test_dir,
                                           'prism_sci_cal.fits')
        with datamodels.open(input_sci_cal_file) as input_sci_model:
            # Work slice by slice so the gaps between slices don't affect the
            # statistic.
            for i in range(30):
                slice_wcs = nirspec.nrs_wcs_set_input(input_sci_model, i)
                x, y = grid_from_bounding_box(slice_wcs.bounding_box)
                _, _, lam = slice_wcs(x, y)
                valid = np.isfinite(lam)
                # Integer pixel indices for the grid coordinates.
                iy, ix = y.astype(int), x.astype(int)
                sci_slice = input_sci_model.data[iy, ix][valid]
                result_slice = result.data[iy, ix][valid]
                sub = result_slice - sci_slice

                # Reject >5-sigma outliers in the science image.
                sci_mean = np.nanmean(sci_slice)
                sci_std = np.nanstd(sci_slice)
                upper = sci_mean + sci_std * 5.0
                lower = sci_mean - sci_std * 5.0
                mask_clean = np.logical_and(sci_slice < upper, sci_slice > lower)

                sub_mean = np.absolute(np.nanmean(sub[mask_clean]))
                assert_allclose(sub_mean, 0, atol=2.0)

        # Compare the background-subtracted science data (result) to a
        # truth file. This data is MultiSlit data.
        result_file = result.meta.filename
        truth_file = self.get_data(*self.ref_loc,
                                   'prism_sci_bkg_masterbackgroundstep.fits')

        outputs = [(result_file, truth_file)]
        self.compare_outputs(outputs)
        result.close()
Exemple #11
0
def test_nirspec_fixedslit_wcs(rtdata):
    """Compare the WCS of each fixed-slit science slit with the truth file."""
    input_file = 'jw00023001001_01101_00001_nrs1_rate.fits'
    output = input_file.replace('rate', 'assign_wcs')

    rtdata.get_data(f"nirspec/fs/{input_file}")
    AssignWcsStep.call(input_file, save_results=True, suffix='assign_wcs')
    rtdata.output = output
    rtdata.get_truth(f"truth/test_nirspec_wcs/{output}")

    science_slits = ('S200A1', 'S200A2', 'S400A1', 'S1600A1')
    with datamodels.open(rtdata.output) as im, datamodels.open(rtdata.truth) as im_truth:
        for slit_name in science_slits:
            assert_wcs_grid_allclose(
                nirspec.nrs_wcs_set_input(im, slit_name),
                nirspec.nrs_wcs_set_input(im_truth, slit_name),
            )
Exemple #12
0
def test_nirspec_ifu_wcs(input_file, rtdata):
    """Compare the IFU WCS of a sample of slices with the truth file."""
    output = input_file.replace('rate.fits', 'assign_wcs.fits')

    rtdata.get_data(f"nirspec/ifu/{input_file}")
    AssignWcsStep.call(input_file, save_results=True, suffix='assign_wcs')
    rtdata.output = output
    rtdata.get_truth(f"truth/test_nirspec_wcs/{output}")

    # Five of the 30 IFU slices, spread over the range 0-29.
    sample_slices = (0, 9, 16, 23, 29)
    with datamodels.open(rtdata.output) as im, datamodels.open(rtdata.truth) as im_truth:
        for slice_index in sample_slices:
            assert_wcs_grid_allclose(
                nirspec.nrs_wcs_set_input(im, slice_index),
                nirspec.nrs_wcs_set_input(im_truth, slice_index),
            )
Exemple #13
0
def test_nirspec_nrs1_wcs(_bigdata):
    """
    Regression test of creating a WCS object and doing pixel to sky transformation.

    Runs ``AssignWcsStep`` on an NRS1 exposure from the ``nrs1-fs`` test set.
    For each of the slits S200A1, S200A2, S400A1 and S1600A1, it compares
    (RA, Dec, lambda) on the slit's bounding-box grid with the WCS stored in
    the reference file.

    Parameters
    ----------
    _bigdata : str
        Root directory of the regression-test data.
    """
    output_file_base, output_file = add_suffix('nrs1_fs_wcs_output.fits',
                                               'assignwcsstep')

    # Remove stale output from a previous run; a missing file is fine.
    try:
        os.remove(output_file)
    except OSError:
        pass

    data_dir = os.path.join(_bigdata, 'nirspec', 'test_wcs', 'nrs1-fs')
    input_file = os.path.join(data_dir,
                              'jw00023001001_01101_00001_NRS1_ramp_fit.fits')
    ref_file = os.path.join(data_dir,
                            'jw00023001001_01101_00001_NRS1_assign_wcs.fits')

    AssignWcsStep.call(input_file,
                       output_file=output_file_base,
                       name='assignwcsstep')

    with ImageModel(output_file) as im, ImageModel(ref_file) as imref:
        for slit in ['S200A1', 'S200A2', 'S400A1', 'S1600A1']:
            w = nirspec.nrs_wcs_set_input(im, slit)
            bbox = w.bounding_box
            # bbox[0] bounds x and bbox[1] bounds y; mgrid yields (y, x).
            y, x = np.mgrid[bbox[1][0]:bbox[1][1],
                            bbox[0][0]:bbox[0][1]]
            ra, dec, lam = w(x, y)
            # The reference WCS must come from the truth file; building it
            # from the step output would make the comparison vacuous.
            wref = nirspec.nrs_wcs_set_input(imref, slit)
            raref, decref, lamref = wref(x, y)
            utils.assert_allclose(ra, raref, equal_nan=True)
            utils.assert_allclose(dec, decref, equal_nan=True)
            utils.assert_allclose(lam, lamref, equal_nan=True)
Exemple #14
0
def test_nirspec_mos_wcs(rtdata):
    """Compare the WCS of every open MOS slit with the truth file."""
    input_file = 'msa_patt_num.fits'
    output = input_file.replace('.fits', '_assign_wcs.fits')

    # Fetch the MSA metadata file together with the exposure itself.
    rtdata.get_data('nirspec/mos/V9621500100101_short_msa.fits')
    rtdata.get_data(f"nirspec/mos/{input_file}")
    AssignWcsStep.call(input_file, save_results=True, suffix='assign_wcs')
    rtdata.output = output
    rtdata.get_truth(f"truth/test_nirspec_wcs/{output}")

    with datamodels.open(rtdata.output) as im, datamodels.open(rtdata.truth) as im_truth:
        # The slit list is taken from the step output, not the truth file.
        open_slit_names = [open_slit.name for open_slit in nirspec.get_open_slits(im)]
        for slit_name in open_slit_names:
            assert_wcs_grid_allclose(
                nirspec.nrs_wcs_set_input(im, slit_name),
                nirspec.nrs_wcs_set_input(im_truth, slit_name),
            )
Exemple #15
0
def test_in_slice(slice, wcs_ifu_grating, ifu_world_coord):
    """
    Check forward/backward consistency of one IFU slice.

    The number of finite RA values from the slice's forward transform on its
    bounding-box grid must equal the number of world points (out of all
    slices) that ``in_ifu_slice`` places on this slice and that the backward
    transform maps back to the detector.
    """
    ra_all, dec_all, lam_all = ifu_world_coord
    im, _ = wcs_ifu_grating("G140H", "F100LP")
    slice_wcs = nirspec.nrs_wcs_set_input(im, slice)
    slicer2world = slice_wcs.get_transform('slicer', 'world')
    detector2slicer = slice_wcs.get_transform('detector', 'slicer')
    x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)

    # Backward: world points that belong to this slice -> detector.
    on_slice = in_ifu_slice(slice_wcs, ra_all, dec_all, lam_all)
    slx, sly, sllam = slicer2world.inverse(ra_all, dec_all, lam_all)
    xinv, _ = detector2slicer.inverse(slx[on_slice], sly[on_slice],
                                      sllam[on_slice])

    # Forward: every pixel of the bounding box -> world.
    ra, _, _ = slice_wcs(x, y)
    assert np.count_nonzero(~np.isnan(ra)) == xinv.size
Exemple #16
0
def ifu_world_coord(wcs_ifu_grating):
    """
    Collect world coordinates for every pixel of all 30 NRS IFU slices.

    Each slice's WCS is evaluated on its bounding-box grid; the results of
    all slices are flattened and concatenated.

    Returns
    -------
    tuple of ndarray
        ``(ra_all, dec_all, lam_all)``, 1-D and of equal length.
    """
    im, _ = wcs_ifu_grating(grating="G140H", filter="F100LP")

    per_slice = []
    for slice_index in range(30):
        slice_wcs = nirspec.nrs_wcs_set_input(im, slice_index)
        x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
        ra, dec, lam = slice_wcs(x, y)
        per_slice.append((ra, dec, lam))

    def _flat_concat(arrays):
        # Flatten each per-slice array and join them end to end.
        return np.concatenate([arr.flatten() for arr in arrays])

    ra_parts, dec_parts, lam_parts = zip(*per_slice)
    return _flat_concat(ra_parts), _flat_concat(dec_parts), _flat_concat(lam_parts)
Exemple #17
0
def test_nirspec_fs_esa():
    """
    Test Nirspec FS mode using build 6 reference files.

    Builds the WCS for a G140M/F100LP fixed-slit exposure. It then checks
    that the ``sca -> v2v3`` transform of slit S200A1 reproduces the ESA
    trace wavelengths at the pixels inside the slit, ignoring pixels where
    the pipeline returns NaN.
    """
    # Test creating the WCS
    filename = create_nirspec_fs_file(grating="G140M", filter="F100LP")
    im = datamodels.ImageModel(filename)
    im.meta.filename = "test_fs.fits"
    refs = create_reference_files(im)

    pipe = nirspec.create_pipeline(im, refs, slit_y_range=[-.5, .5])
    w = wcs.WCS(pipe)
    im.meta.wcs = w
    # Test evaluating the WCS
    w1 = nirspec.nrs_wcs_set_input(im, "S200A1")
    sca2world = w1.get_transform('sca', 'v2v3')

    trace_file = get_file_path(
        'Trace_SLIT_A_200_1_V84600010001P0000000002101_39547_JLAB88.fits')
    # All use of the HDU data stays inside the context so the file is closed
    # even when the assertion fails.
    with fits.open(trace_file) as ref:
        pyw = astwcs.WCS(ref[1].header)

        # Get positions within the slit and the corresponding lambda.
        slit1 = ref[5].data  # y offset on the slit
        lam = ref[4].data

        # Filter out locations outside the slit.
        cond = np.logical_and(slit1 < .5, slit1 > -.5)
        y, x = cond.nonzero()  # 0-based

        x, y = pyw.wcs_pix2world(x, y, 0)
        # The pipeline works with 0-based coordinates.
        x -= 1
        y -= 1

        # Outputs are (v2, v3, lambda); only lambda is compared.
        _, _, lp = sca2world(x, y)
        # w1 outputs wavelength in microns; scale by 1e-6 to the units of the
        # trace file (presumably meters).
        lp *= 1e-6
        lam = lam[cond]
        nan_cond = ~np.isnan(lp)
        assert_allclose(lp[nan_cond], lam[nan_cond], atol=10**-13)
Exemple #18
0
    def test_nirspec_ifu_masterbg_user(self):
        """
        Regression test of master background subtraction for NRS IFU when a
        user 1-D spectrum is provided.

        Three checks are made:

        1. The 1-D spectrum extracted from the background-subtracted cube
           agrees with the truth extraction of the background-free science
           data (|mean difference| within 5). The comparison uses the
           wavelength range where both science and user-background fluxes
           are positive.
        2. Within each of the 30 IFU slices, the background-subtracted image
           matches the background-free science image (|mean difference|
           within 2 after 5-sigma clipping of the science pixels).
        3. The saved step output matches the truth file.
        """
        # input file has 2-D background image added to it
        input_file = self.get_data(*self.test_dir, 'prism_sci_bkg_cal.fits')

        # user-provided 1-D background was created from the 2-D background image
        user_background = self.get_data(*self.test_dir, 'prism_bkg_x1d.fits')

        result = MasterBackgroundStep.call(input_file,
                                           user_background=user_background,
                                           save_results=True)

        # Test 1: compare the extracted spectrum of the science data with no
        # background to the extracted spectrum of the MasterBackground output.
        # cube_build has to be run on the data first.
        result_s3d = CubeBuildStep.call(result)
        # run 1-D extract on results from MasterBackground step
        result_1d = Extract1dStep.call(result_s3d, subtract_background=False)
        result_spec_1d = result_1d.spec[0].spec_table['flux']

        # Wavelength range where the user background has positive flux.
        with datamodels.open(user_background) as user_background_model:
            user_table = user_background_model.spec[0].spec_table
            user_wave = user_table['wavelength']
            user_flux = user_table['flux']
            user_wave_valid = np.where(user_flux > 0)
            min_user_wave = np.amin(user_wave[user_wave_valid])
            max_user_wave = np.amax(user_wave[user_wave_valid])

        # 1-D extracted spectrum of the science data, from the truth directory.
        input_sci_1d_file = self.get_data(*self.ref_loc, 'prism_sci_extract1d.fits')
        with datamodels.open(input_sci_1d_file) as sci_1d:
            sci_spec_1d = sci_1d.spec[0].spec_table['flux']
            sci_spec_wave = sci_1d.spec[0].spec_table['wavelength']

            # Restrict to the wavelength range covered by both user and science.
            sci_wave_valid = np.where(sci_spec_1d > 0)
            min_wave = max(np.amin(sci_spec_wave[sci_wave_valid]), min_user_wave)
            max_wave = min(np.amax(sci_spec_wave[sci_wave_valid]), max_user_wave)

            sub_spec = sci_spec_1d - result_spec_1d
            valid = np.where(np.logical_and(sci_spec_wave > min_wave,
                                            sci_spec_wave < max_wave))
            sub_spec = sub_spec[valid]
            sub_spec = sub_spec[1:-2]  # endpoints are wacky

        mean_sub = np.absolute(np.nanmean(sub_spec))
        assert_allclose(mean_sub, 0, atol=5.0)

        # Test 2: compare the science data with no background to the
        # background-subtracted output of the MasterBackground step.
        input_sci_cal_file = self.get_data(*self.test_dir,
                                           'prism_sci_cal.fits')
        with datamodels.open(input_sci_cal_file) as input_sci_model:
            # Work slice by slice so the gaps between slices don't affect the
            # statistic.
            for i in range(30):
                slice_wcs = nirspec.nrs_wcs_set_input(input_sci_model, i)
                x, y = grid_from_bounding_box(slice_wcs.bounding_box)
                _, _, lam = slice_wcs(x, y)
                valid = np.isfinite(lam)
                # Integer pixel indices for the grid coordinates.
                iy, ix = y.astype(int), x.astype(int)
                sci_slice = input_sci_model.data[iy, ix][valid]
                result_slice = result.data[iy, ix][valid]
                sub = result_slice - sci_slice

                # Reject >5-sigma outliers in the science image.
                sci_mean = np.nanmean(sci_slice)
                sci_std = np.nanstd(sci_slice)
                upper = sci_mean + sci_std * 5.0
                lower = sci_mean - sci_std * 5.0
                mask_clean = np.logical_and(sci_slice < upper, sci_slice > lower)

                sub_mean = np.absolute(np.nanmean(sub[mask_clean]))
                assert_allclose(sub_mean, 0, atol=2.0)

        # Test 3: compare the background-subtracted science data (result) to
        # a truth file. This data is MultiSlit data.
        result_file = result.meta.filename
        truth_file = self.get_data(*self.ref_loc,
                                   'prism_sci_bkg_masterbackgroundstep.fits')

        outputs = [(result_file, truth_file)]
        self.compare_outputs(outputs)
        result.close()
Exemple #19
0
def _wcs_log(msg, log_msgs):
    """Print ``msg`` and append it to ``log_msgs`` so it is both shown and captured."""
    print(msg)
    log_msgs.append(msg)


def _wcs_stat_is_nan(value):
    """Return True when a statistic value is NaN.

    ``value is np.nan`` only recognises the ``np.nan`` singleton; a NaN computed
    at runtime is generally a different object, so fall back to ``np.isnan``.
    Values ``np.isnan`` cannot handle (strings, None, ...) are reported as not NaN.
    """
    if value is np.nan:
        return True
    try:
        return bool(np.isnan(value))
    except (TypeError, ValueError):
        return False


def _wcs_record_test(reldiff_data, tested_quantity, threshold_diff, log_msgs, results):
    """Log the statistics of one comparison, judge it and store the verdict.

    Args:
        reldiff_data: 4-tuple returned by ``auxfunc.get_reldiffarr_and_stats``:
            (difference image, non-NaN differences, statistics, printable statistics lines).
        tested_quantity: str, label of the compared quantity (used in messages and as result key).
        threshold_diff: float, threshold handed to ``auxfunc.does_median_pass_tes``.
        log_msgs: list collecting all messages.
        results: list that receives a ``{tested_quantity: verdict}`` dict.

    Returns:
        (difference image, non-NaN differences, statistics), for plotting.
    """
    diff_img, diff_notnan, diff_stats, print_stats_strings = reldiff_data
    # the statistics lines are only captured here, not printed
    log_msgs.extend(print_stats_strings)
    # statistics[1] is the value judged by does_median_pass_tes (presumably the median)
    result = auxfunc.does_median_pass_tes(diff_stats[1], threshold_diff)
    _wcs_log('Result for test of ' + tested_quantity + ': ' + result, log_msgs)
    results.append({tested_quantity: result})
    return diff_img, diff_notnan, diff_stats


def _wcs_v2v3_reldiff(threshold_diff, esa_slity, esa_values, pipe_values, tested_quantity):
    """Compute V2/V3 relative differences, redoing them on pipeline values / 3600 if needed.

    When the first statistic of the initial comparison is positive, the pipeline
    values are divided by 3600 ("converting to degrees") and compared again.
    """
    reldiff_data = auxfunc.get_reldiffarr_and_stats(
        threshold_diff, esa_slity, esa_values, pipe_values, tested_quantity)
    if reldiff_data[-2][0] > 0.0:
        print("\nConverting pipeline results to degrees to compare with ESA")
        reldiff_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, esa_slity, esa_values, pipe_values / 3600., tested_quantity)
    return reldiff_data


def _wcs_read_esa_trace(esafile, esahdulist, det, log_msgs):
    """Read the detector-specific extensions of an ESA trace file.

    Args:
        esafile: path of the ESA trace file.
        esahdulist: the already opened HDU list of ``esafile``.
        det: str, detector name ("NRS1" or "NRS2").
        log_msgs: list collecting all messages.

    Returns:
        dict with keys "wave", "slity", "msax", "msay", "pyw" (astropy WCS of the
        LAMBDA extension) and "v2v3x"/"v2v3y" (None when the file has no V2/V3
        extensions), or None when the required extensions are missing or the
        detector is not NRS1/NRS2.
    """
    suffix = {"NRS1": "1", "NRS2": "2"}.get(det)
    if suffix is None:
        _wcs_log("PTT does not recognize detector " + str(det) + ", skipping test for this slitlet...",
                 log_msgs)
        return None
    try:
        # DATA<n> is not compared; reading it only checks that the extension exists.
        fits.getdata(esafile, "DATA" + suffix)
        esa = {
            "wave": fits.getdata(esafile, "LAMBDA" + suffix),
            "slity": fits.getdata(esafile, "SLITY" + suffix),
            "msax": fits.getdata(esafile, "MSAX" + suffix),
            "msay": fits.getdata(esafile, "MSAY" + suffix),
            "pyw": wcs.WCS(esahdulist["LAMBDA" + suffix].header),
            "v2v3x": None,
            "v2v3y": None,
        }
    except KeyError:
        _wcs_log("PTT did not find ESA extensions that match detector " + det
                 + ", skipping test for this slitlet...", log_msgs)
        return None
    try:
        v2v3x = fits.getdata(esafile, "V2V3X" + suffix)
        v2v3y = fits.getdata(esafile, "V2V3Y" + suffix)
    except KeyError:
        _wcs_log("Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions.",
                 log_msgs)
    else:
        esa["v2v3x"], esa["v2v3y"] = v2v3x, v2v3y
    return esa


def _wcs_plot_diff(what, diff_img, diff_notnan, diff_stats, title, xlabel, plt_name,
                   show_figs, save_figs, log_msgs):
    """Plot a relative-difference image plus histogram, unless its statistic is NaN.

    Args:
        what: str, quantity name used in the "Unable to create plot" message.
        diff_img, diff_notnan, diff_stats: outputs of ``_wcs_record_test``.
        title, xlabel: str, plot title and histogram x label.
        plt_name: str, output file name for the plot.
        show_figs, save_figs: booleans forwarded to the plotting function.
        log_msgs: list collecting all messages.
    """
    if _wcs_stat_is_nan(diff_stats[1]):
        _wcs_log("Unable to create plot of relative " + what + " difference.", log_msgs)
        return
    bins = 15  # binning for the histograms, if None the function will automatically calculate them
    info_img = [title, "x (pixels)", "y (pixels)"]
    info_hist = [xlabel, "N", bins, diff_stats]
    auxfunc.plt_two_2Dimgandhist(diff_img,
                                 diff_notnan,
                                 info_img,
                                 info_hist,
                                 plt_name=plt_name,
                                 plt_origin=None,
                                 show_figs=show_figs,
                                 save_figs=save_figs)


def _wcs_summarize_results(total_test_result, log_msgs):
    """Log every individual verdict and return the overall result.

    Args:
        total_test_result: mapping slitlet name -> list of ``{quantity: verdict}`` dicts.
        log_msgs: list collecting all messages.

    Returns:
        "PASSED" when at least one test ran and none was "FAILED", otherwise "FAILED".
    """
    any_run = False
    any_failed = False
    for sl, testlist in total_test_result.items():
        for tdict in testlist:
            for t, tr in tdict.items():
                any_run = True
                failed = tr == "FAILED"
                any_failed = any_failed or failed
                verdict = "FAILED" if failed else "PASSED"
                _wcs_log("\n * The test of " + t + " for slitlet " + sl + "  " + verdict + ".", log_msgs)

    final_result = "PASSED" if any_run and not any_failed else "FAILED"
    _wcs_log("\n *** Final result for assign_wcs test will be reported as " + final_result + " *** \n",
             log_msgs)
    return final_result


def compare_wcs(infile_name,
                esa_files_path,
                msa_conf_name,
                show_figs=True,
                save_figs=False,
                threshold_diff=1.0e-7,
                mode_used=None,
                debug=False):
    """
    Compare the pipeline WCS of a NIRSpec MOS exposure with the ESA intermediary files.

    For every open slitlet the pipeline WCS is evaluated at the pixel positions of the
    matching ESA trace file, and the relative differences in wavelength, slit-y, MSA-x,
    MSA-y and (when the ESA file provides them) V2/V3 are logged, tested against
    ``threshold_diff`` and optionally plotted.

    Args:
        infile_name: str, name of the output fits file from the assign_wcs step (with full path).
            When mode_used == "MOS_sim", "assign_wcs" in the name is replaced by "extract_2d"
            and that product is read instead.
        esa_files_path: str, full path of where to find all ESA intermediary products to make comparisons for the tests
        msa_conf_name: str, full path where to find the shutter configuration file
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots or not
        threshold_diff: float, threshold difference between pipeline output and ESA file
        mode_used: string, mode used in the PTT configuration file
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - plots, if told to save and/or show them.
        - FINAL_TEST_RESULT: str, "PASSED" if at least one comparison ran and none failed,
          "FAILED" otherwise, or "skip" when the MSA shutter configuration file or an ESA
          file could not be found.
        - log_msgs: list, all print statements captured in this variable

    Side effects:
        The MSA shutter configuration file is copied into the current working directory
        for the duration of the comparison; the copy is removed before returning.
    """
    import shutil

    log_msgs = []
    is_mos_sim = mode_used == "MOS_sim"

    # get grating and filter info from the rate file header
    if is_mos_sim:
        infile_name = infile_name.replace("assign_wcs", "extract_2d")
    _wcs_log('wcs validation test infile_name= ' + infile_name, log_msgs)
    det = fits.getval(infile_name, "DETECTOR", 0)
    lamp = fits.getval(infile_name, "LAMP", 0)
    grat = fits.getval(infile_name, "GRATING", 0)
    filt = fits.getval(infile_name, "FILTER", 0)
    msametfl = fits.getval(infile_name, "MSAMETFL", 0)
    _wcs_log("from assign_wcs file  -->     Detector: " + det + "   Grating: " + grat
             + "   Filter: " + filt + "   Lamp: " + lamp, log_msgs)

    # check that shutter configuration file in header is the same as given in PTT_config file
    if msametfl != os.path.basename(msa_conf_name):
        _wcs_log("* WARNING! MSA config file name given in PTT_config file does not match "
                 "the MSAMETFL keyword in main header.\n", log_msgs)

    # Copy the MSA shutter configuration file into the working directory.
    # shutil.copy raises FileNotFoundError when the source is missing (an external
    # `cp` would only return a non-zero exit status).
    try:
        msa_copy = shutil.copy(msa_conf_name, ".")
    except shutil.SameFileError:
        # The file already lives in the working directory: nothing was copied,
        # so nothing may be deleted afterwards.
        msa_copy = None
    except FileNotFoundError:
        for msg in (" * PTT is not able to locate the MSA shutter configuration file. Please make sure that the msa_conf_name variable in",
                    "   the PTT_config.cfg file is pointing exactly to where the fits file exists (i.e. full path and name). ",
                    "   -> The WCS test is now set to skip and no plots will be generated. "):
            _wcs_log(msg, log_msgs)
        return "skip", log_msgs

    model = None
    # test results per slitlet, used to determine if pytest is passed or not
    total_test_result = OrderedDict()
    try:
        # get shutter info from metadata
        shutter_info = fits.getdata(msa_conf_name, extname="SHUTTER_INFO")  # this is generally ext=2
        pslit_list = shutter_info.field("slitlet_id").tolist()
        quad = shutter_info.field("shutter_quadrant")
        row = shutter_info.field("shutter_row")
        col = shutter_info.field("shutter_column")
        _wcs_log('Using this MSA shutter configuration file: ' + msa_conf_name, log_msgs)

        if is_mos_sim:
            # this command works for the extract_2d and flat_field output files
            model = datamodels.MultiSlitModel(infile_name)
            slits_list = model.slits
        else:
            # get the datamodel from the assign_wcs output file
            model = datamodels.ImageModel(infile_name)
            # Open slitlets that are also projected on the detector (nirspec.get_open_slits
            # would include open slitlets that are later dropped for falling off the detector).
            slits_list = model.meta.wcs.get_transform('gwa', 'slit_frame').slits
            if debug:
                print("Instrument Configuration")
                print("Detector: {}".format(model.meta.instrument.detector))
                print("GWA: {}".format(model.meta.instrument.grating))
                print("Filter: {}".format(model.meta.instrument.filter))
                print("Lamp: {}".format(model.meta.instrument.lamp_state))
                print("GWA_XTILT: {}".format(model.meta.instrument.gwa_xtilt))
                print("GWA_YTILT: {}".format(model.meta.instrument.gwa_ytilt))
                print("GWA_TTILT: {}".format(model.meta.instrument.gwa_tilt))

        for slit in slits_list:
            name = slit.name
            _wcs_log("\nWorking with slit: " + str(name), log_msgs)

            # get the right index in the list of open shutters
            slitlet_idx = pslit_list.index(int(name))

            # Get the ESA trace
            _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file()
            _wcs_log("Using this raw data file to find the corresponding ESA file: "
                     + raw_data_root_file, log_msgs)
            q, r, c = quad[slitlet_idx], row[slitlet_idx], col[slitlet_idx]
            _wcs_log("Pipeline shutter info:   quadrant= " + str(q) + "   row= " + str(r)
                     + "   col=" + str(c), log_msgs)
            esafile = auxfunc.get_esafile(esa_files_path, raw_data_root_file, "MOS", [q, r, c])

            # skip the test if the esafile was not found
            if "ESA file not found" in esafile:
                _wcs_log(" * compare_wcs_mos.py is exiting because the corresponding ESA file was not found.",
                         log_msgs)
                _wcs_log("   -> The WCS test is now set to skip and no plots will be generated. ", log_msgs)
                return "skip", log_msgs

            # get_esafile may return a two-element result; when the second entry is
            # empty only the first one is used.
            if len(esafile) == 2 and len(esafile[-1]) == 0:
                esafile = esafile[0]
            _wcs_log("Using this ESA file: \n" + str(esafile), log_msgs)

            with fits.open(esafile) as esahdulist:
                print("* ESA file contents ")
                esahdulist.info()
                esa_header = esahdulist[0].header
                esa_shutter_i = esa_header['SHUTTERI']
                esa_shutter_j = esa_header['SHUTTERJ']
                esa_quadrant = esa_header['QUADRANT']
                # str() everywhere: header values need not be strings
                if debug:
                    _wcs_log("ESA shutter info:   quadrant=" + str(esa_quadrant) + "   shutter_i="
                             + str(esa_shutter_i) + "   shutter_j=" + str(esa_shutter_j), log_msgs)

                # first check if ESA shutter info is the same as pipeline
                _wcs_log("For slitlet" + str(name), log_msgs)
                for label, pipe_value, esa_value, tail in (("quadrant", q, esa_quadrant, ""),
                                                           ("row", r, esa_shutter_i, ""),
                                                           ("column", c, esa_shutter_j, "\n")):
                    if pipe_value == esa_value:
                        msg = ("\n -> Same " + label + " for pipeline and ESA data: "
                               + str(pipe_value) + tail)
                    else:
                        msg = ("\n -> Missmatch of " + label + " for pipeline and ESA data: "
                               + str(pipe_value) + " vs " + str(esa_value) + tail)
                    _wcs_log(msg, log_msgs)

                # Assign variables according to detector
                esa = _wcs_read_esa_trace(esafile, esahdulist, det, log_msgs)
            if esa is None:
                continue

            # get the WCS object for this particular slit
            if is_mos_sim:
                wcs_slice = slit.wcs
            else:
                try:
                    wcs_slice = nirspec.nrs_wcs_set_input(model, name)
                except ValueError:
                    _wcs_log("* WARNING: Slitlet " + str(name) + " was not found in the model. "
                             "Skipping test for this slitlet.", log_msgs)
                    continue
                if debug:
                    # To get specific pixel values use following syntax:
                    det2slit = wcs_slice.get_transform('detector', 'slit_frame')
                    dbg_slitx, dbg_slity, dbg_lam = det2slit(700, 1080)
                    print("slitx: ", dbg_slitx)
                    print("slity: ", dbg_slity)
                    print("lambda: ", dbg_lam)
                    # The number of inputs and outputs in each frame can vary. This can be checked with:
                    print('Number on inputs: ', det2slit.n_inputs)
                    print('Number on outputs: ', det2slit.n_outputs)

            # Create x, y indices using the Trace WCS
            esa_wave = esa["wave"]
            esa_slity = esa["slity"]
            pipey, pipex = np.mgrid[:esa_wave.shape[0], :esa_wave.shape[1]]
            esax, esay = esa["pyw"].all_pix2world(pipex, pipey, 0)
            if det == "NRS2":
                _wcs_log("NRS2 needs a flip", log_msgs)
                esax = 2049 - esax
                esay = 2049 - esay
            # Pipeline transforms are evaluated one pixel below the ESA trace
            # coordinates (presumably 1-based vs 0-based pixels -- confirm).
            xpix = esax - 1
            ypix = esay - 1

            slitlet_name = str(r) + "_" + str(c)
            slitlet_test_result_list = []

            # Wavelength: wcs_slice returns RA, DEC, LAMBDA; LAMBDA is scaled by 1e-6
            # before comparing with ESA (assumed microns vs meters -- confirm).
            _, _, pwave = wcs_slice(xpix, ypix)
            pwave *= 10**-6
            tested_quantity = "Wavelength Difference"
            wave_img, wave_notnan, wave_stats = _wcs_record_test(
                auxfunc.get_reldiffarr_and_stats(threshold_diff, esa_slity, esa_wave, pwave, tested_quantity),
                tested_quantity, threshold_diff, log_msgs, slitlet_test_result_list)

            # Slit-y (relative, not absolute, differences)
            det2slit = wcs_slice.get_transform('detector', 'slit_frame')
            _, pslity, _ = det2slit(xpix, ypix)
            tested_quantity = "Slit-Y Difference"
            slity_img, slity_notnan, slity_stats = _wcs_record_test(
                auxfunc.get_reldiffarr_and_stats(threshold_diff, esa_slity, esa_slity, pslity,
                                                 tested_quantity, abs=False),
                tested_quantity, threshold_diff, log_msgs, slitlet_test_result_list)

            # MSA x and y
            detector2msa = wcs_slice.get_transform("detector", "msa_frame")
            pmsax, pmsay, _ = detector2msa(xpix, ypix)
            tested_quantity = "MSA_X Difference"
            msax_img, msax_notnan, msax_stats = _wcs_record_test(
                auxfunc.get_reldiffarr_and_stats(threshold_diff, esa_slity, esa["msax"], pmsax, tested_quantity),
                tested_quantity, threshold_diff, log_msgs, slitlet_test_result_list)
            tested_quantity = "MSA_Y Difference"
            msay_img, msay_notnan, msay_stats = _wcs_record_test(
                auxfunc.get_reldiffarr_and_stats(threshold_diff, esa_slity, esa["msay"], pmsay, tested_quantity),
                tested_quantity, threshold_diff, log_msgs, slitlet_test_result_list)

            # (what, image, non-NaN values, stats, title, histogram x label, file suffix)
            plots = [
                ("wavelength", wave_img, wave_notnan, wave_stats,
                 r"Relative wavelength difference = $\Delta \lambda$",
                 r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{ESA}) / \lambda_{ESA}$",
                 "_rel_wave_diffs.pdf"),
                ("slit-y", slity_img, slity_notnan, slity_stats,
                 r"Relative slit position = $\Delta$slit_y",
                 r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{ESA}$)/slit_y$_{ESA}$",
                 "_rel_slitY_diffs.pdf"),
                ("MSA-x", msax_img, msax_notnan, msax_stats,
                 r"Relative MSA-x Difference = $\Delta$MSA_x",
                 r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{ESA}$)/MSA_x$_{ESA}$",
                 "_rel_MSAx_diffs.pdf"),
                ("MSA-y", msay_img, msay_notnan, msay_stats,
                 r"Relative MSA-y Difference = $\Delta$MSA_y",
                 r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{ESA}$)/MSA_y$_{ESA}$",
                 "_rel_MSAy_diffs.pdf"),
            ]

            # V2 and V3, only when the ESA file provides them
            if esa["v2v3x"] is not None:
                detector2v2v3 = wcs_slice.get_transform("detector", "v2v3")
                pv2, pv3, _ = detector2v2v3(xpix, ypix)
                tested_quantity = "V2 difference"
                v2_img, v2_notnan, v2_stats = _wcs_record_test(
                    _wcs_v2v3_reldiff(threshold_diff, esa_slity, esa["v2v3x"], pv2, tested_quantity),
                    tested_quantity, threshold_diff, log_msgs, slitlet_test_result_list)
                tested_quantity = "V3 difference"
                v3_img, v3_notnan, v3_stats = _wcs_record_test(
                    _wcs_v2v3_reldiff(threshold_diff, esa_slity, esa["v2v3y"], pv3, tested_quantity),
                    tested_quantity, threshold_diff, log_msgs, slitlet_test_result_list)
                plots += [
                    ("V2", v2_img, v2_notnan, v2_stats,
                     r"Relative V2 Difference = $\Delta$V2",
                     r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{ESA}$)/V2$_{ESA}$",
                     "_rel_V2_diffs.pdf"),
                    ("V3", v3_img, v3_notnan, v3_stats,
                     r"Relative V3 Difference = $\Delta$V3",
                     r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{ESA}$)/V3$_{ESA}$",
                     "_rel_V3_diffs.pdf"),
                ]

            total_test_result[slitlet_name] = slitlet_test_result_list

            # PLOTS
            if show_figs or save_figs:
                basenameinfile_name = os.path.basename(infile_name)
                main_title = filt + "   " + grat + "   SLITLET=" + slitlet_name + "\n"
                for what, diff_img, diff_notnan, diff_stats, title, xlabel, suffix in plots:
                    plt_name = infile_name.replace(basenameinfile_name,
                                                   slitlet_name + "_" + det + suffix)
                    _wcs_plot_diff(what, diff_img, diff_notnan, diff_stats,
                                   main_title + title + "\n", xlabel, plt_name,
                                   show_figs, save_figs, log_msgs)
            else:
                _wcs_log("NO plots were made because show_figs and save_figs were both set to False. \n",
                         log_msgs)
    finally:
        if model is not None:
            model.close()
        # remove the copy of the MSA shutter configuration file (only if one was made)
        if msa_copy is not None and os.path.exists(msa_copy):
            os.remove(msa_copy)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    return _wcs_summarize_results(total_test_result, log_msgs), log_msgs
Exemple #20
0
def extract2d(input_model, which_subarray=None):
    """
    Extract 2D cutouts of the open NIRSpec slits from a detector image.

    Parameters
    ----------
    input_model : data model
        Image with an assigned WCS in ``input_model.meta.wcs``.
    which_subarray : str, optional
        Name of a single slit to extract (looked up in ``nirspec.slit_name2id``).
        By default every slit returned by ``nirspec.get_open_slits`` is extracted.

    Returns
    -------
    For NRS_FIXEDSLIT / NRS_MSASPEC exposures, a ``models.MultiSlitModel`` with one
    entry per extracted slit and ``meta.cal_step.extract_2d = 'COMPLETE'``.
    For any other exposure type, ``input_model`` itself with
    ``meta.cal_step.extract_2d = 'SKIPPED'``.
    """
    exp_type = input_model.meta.exposure.type.upper()
    log.info('EXP_TYPE is {0}'.format(exp_type))
    if exp_type not in ['NRS_FIXEDSLIT', 'NRS_MSASPEC']:
        # Not a slit mode: record the skip on the model that is returned
        # (no output model exists at this point).
        input_model.meta.cal_step.extract_2d = 'SKIPPED'
        return input_model

    if which_subarray is None:
        open_slits = nirspec.get_open_slits(input_model)
    else:
        open_slits = [nirspec.slit_name2id[which_subarray]]
    log.info('open slits {0}'.format(open_slits))

    output_model = models.MultiSlitModel()
    output_model.update(input_model)

    _, wrange = nirspec.spectral_order_wrange_from_model(input_model)

    if exp_type == 'NRS_FIXEDSLIT':
        slit_names = [nirspec.slit_id2name[tuple(slit)] for slit in open_slits]
    else:
        slit_names = [str(slit) for slit in open_slits]

    for slit, slit_name in zip(open_slits, slit_names):
        slit_wcs = nirspec.nrs_wcs_set_input(input_model.meta.wcs, slit[0],
                                             slit[1], wrange)
        if input_model.meta.subarray.ystart is not None:
            xlo, xhi, ylo, yhi = _adjust_subarray(input_model, slit_wcs)
        else:
            log.info('Subarray ystart metadata value not found')
            xlo, xhi = slit_wcs.domain[0]['lower'], slit_wcs.domain[0]['upper']
            ylo, yhi = slit_wcs.domain[1]['lower'], slit_wcs.domain[1]['upper']

        log.info('Name of subarray extracted: %s', slit_name)
        log.info('Subarray x-extents are: %s %s', xlo, xhi)
        log.info('Subarray y-extents are: %s %s', ylo, yhi)

        ext_data = input_model.data[ylo:yhi + 1, xlo:xhi + 1].copy()
        ext_err = input_model.err[ylo:yhi + 1, xlo:xhi + 1].copy()
        ext_dq = input_model.dq[ylo:yhi + 1, xlo:xhi + 1].copy()
        new_model = models.ImageModel(data=ext_data, err=ext_err, dq=ext_dq)
        shape = ext_data.shape
        # NOTE(review): with pixel centres 0..n-1 the edges are -0.5..n-0.5, so an
        # exclusive 'upper' of n + 0.5 admits one extra pixel -- confirm the
        # intended domain convention.
        slit_wcs.domain = [{
            'lower': -0.5,
            'upper': shape[1] + 0.5,
            'includes_lower': True,
            'includes_upper': False
        }, {
            'lower': -0.5,
            'upper': shape[0] + 0.5,
            'includes_lower': True,
            'includes_upper': False
        }]
        new_model.meta.wcs = slit_wcs
        output_model.slits.append(new_model)

        # set x/ystart values relative to full detector space, so need
        # to account for x/ystart values of input if it's a subarray.
        # Missing start values (the ystart-is-None branch above) are treated as
        # a full-frame start of 1 instead of crashing on ``None + offset``.
        # NOTE(review): assumes the 1-based subarray start convention -- confirm.
        sub_xstart = input_model.meta.subarray.xstart
        sub_ystart = input_model.meta.subarray.ystart
        if sub_xstart is None:
            sub_xstart = 1
        if sub_ystart is None:
            sub_ystart = 1
        nslit = len(output_model.slits) - 1
        output_model.slits[nslit].name = slit_name
        output_model.slits[nslit].xstart = sub_xstart + xlo
        output_model.slits[nslit].xsize = xhi - xlo + 1
        output_model.slits[nslit].ystart = sub_ystart + ylo
        output_model.slits[nslit].ysize = yhi - ylo + 1

    # Set the step status to COMPLETE
    output_model.meta.cal_step.extract_2d = 'COMPLETE'
    return output_model
Exemple #21
0
def FindFootPrintNIRSPEC(self, input, this_channel):
    #********************************************************************************
    """
    Short Summary
    -------------
    Find the spatial and spectral footprint of the NIRSpec IFU.

    For each IFU slice (slices 0 through 29) the slice WCS is evaluated on the
    pixel grid spanned by the slice domain. The minimum and maximum of the two
    spatial output coordinates (each multiplied by 60.0) and of the wavelength
    are recorded, and the overall extremes over all slices are returned.
    NaN values produced by the WCS are ignored.

    Parameters
    ----------
    input: input model; ``input.meta.wcs`` is used to build each slice WCS and
        the model is passed to ``nirspec.spectral_order_wrange_from_model``
    this_channel: channel working with (not used by this function; kept for
        interface compatibility)

    Returns
    -------
    a_min, a_max, b_min, b_max, lambda_min, lambda_max
        min and max of the first spatial coordinate, of the second spatial
        coordinate and of the wavelength over all slices.
    """
    # loop over all the region (Slices) in the Channel; for each slice evaluate
    # the WCS on its x,y pixel grid and keep the min & max of the spatial
    # coordinates and wavelength - these are of the pixel centers
    start_slice = 0
    end_slice = 29

    nslices = end_slice - start_slice + 1

    # Two entries (min, max) per slice
    a_slice = np.zeros(nslices * 2)
    b_slice = np.zeros(nslices * 2)
    lambda_slice = np.zeros(nslices * 2)

    _, wrange = nirspec.spectral_order_wrange_from_model(input)
    for k, i in enumerate(range(start_slice, end_slice + 1)):
        slice_wcs = nirspec.nrs_wcs_set_input(input.meta.wcs, 0, i, wrange)
        b = _domain_to_bounds(slice_wcs.domain)
        y, x = np.mgrid[b[1][0]:b[1][1], b[0][0]:b[0][1]]
        v2, v3, lam = slice_wcs(x, y)

        coord1 = v2 * 60.0
        coord2 = v3 * 60.0

        a_slice[2 * k] = np.nanmin(coord1)
        a_slice[2 * k + 1] = np.nanmax(coord1)

        b_slice[2 * k] = np.nanmin(coord2)
        b_slice[2 * k + 1] = np.nanmax(coord2)

        lambda_slice[2 * k] = np.nanmin(lam)
        lambda_slice[2 * k + 1] = np.nanmax(lam)

    # NaN-aware reductions: the builtin min/max give order-dependent results
    # when a slice produced NaN extremes, np.nanmin/np.nanmax skip them.
    a_min = np.nanmin(a_slice)
    a_max = np.nanmax(a_slice)

    b_min = np.nanmin(b_slice)
    b_max = np.nanmax(b_slice)

    lambda_min = np.nanmin(lambda_slice)
    lambda_max = np.nanmax(lambda_slice)

    return a_min, a_max, b_min, b_max, lambda_min, lambda_max
Exemple #22
0
def _pathloss_vectors_from_aperture(aperture, xcenter, ycenter):
    """Evaluate the 1-D pathloss vectors of *aperture* at the source position.

    Parameters
    ----------
    aperture : aperture object from the pathloss reference model
        Provides ``pointsource_data``/``pointsource_wcs`` and
        ``uniform_data``/``uniform_wcs``.
    xcenter, ycenter : source position as returned by ``get_center``

    Returns
    -------
    pointsource : tuple
        ``(wavelength, pathloss_vector)`` for a point source; wavelengths are
        as stored in the reference file (meters).
    uniform : tuple
        ``(wavelength, pathloss_vector)`` for a uniformly illuminated source.
    is_inside : bool
        Inside-slit flag from the point-source calculation (the flag from the
        uniform-source calculation is not used).
    """
    (wavelength_pointsource,
     pathloss_pointsource_vector,
     is_inside) = calculate_pathloss_vector(aperture.pointsource_data,
                                            aperture.pointsource_wcs,
                                            xcenter, ycenter)
    (wavelength_uniformsource,
     pathloss_uniform_vector,
     _) = calculate_pathloss_vector(aperture.uniform_data,
                                    aperture.uniform_wcs,
                                    xcenter, ycenter)
    return ((wavelength_pointsource, pathloss_pointsource_vector),
            (wavelength_uniformsource, pathloss_uniform_vector),
            is_inside)


def _interpolate_pathloss_2d(wavelength_array, use_pointsource, pointsource, uniform):
    """Interpolate the selected 1-D pathloss vector onto *wavelength_array*.

    Wavelengths in the reference file are in meters while the science
    wavelengths are in microns, so the reference wavelengths are scaled by
    1e6 first. A new array is created for the scaled wavelengths, so the
    vectors passed in are not modified.
    """
    wavelength_ref, pathloss_vector = pointsource if use_pointsource else uniform
    return interpolate_onto_grid(wavelength_array,
                                 wavelength_ref * 1.0e6,
                                 pathloss_vector)


def _apply_pathloss_2d(model, pathloss_2d):
    """Apply the 2-D pathloss correction to *model* in place and attach it.

    ``data`` and ``err`` are divided by the correction and the variance
    arrays by its square; ``var_flat`` is only corrected when it is present
    and non-empty. The correction itself is stored as ``model.pathloss``.
    """
    model.data /= pathloss_2d
    model.err /= pathloss_2d
    model.var_poisson /= pathloss_2d**2
    model.var_rnoise /= pathloss_2d**2
    if model.var_flat is not None and np.size(model.var_flat) > 0:
        model.var_flat /= pathloss_2d**2
    model.pathloss = pathloss_2d


def do_correction(input_model, pathloss_model):
    """
    Short Summary
    -------------
    Execute all tasks for Path Loss Correction

    The correction is applied to a copy of ``input_model``. Handled exposure
    types are NRS_MSASPEC, NRS_FIXEDSLIT, NRS_BRIGHTOBJ, NRS_IFU and
    NIS_SOSS; for any other type the copy is returned unchanged.

    Parameters
    ----------
    input_model : data model object
        science data to be corrected

    pathloss_model : pathloss model object
        pathloss correction data

    Returns
    -------
    output_model : data model object
        Science data with pathloss extensions added.
        ``meta.cal_step.pathloss`` is set to 'COMPLETE' for the handled
        exposure types, or 'SKIPPED' when a NIS_SOSS exposure cannot be
        corrected.

    """
    exp_type = input_model.meta.exposure.type
    log.info('Input exposure type is {}'.format(exp_type))
    output_model = input_model.copy()
    if exp_type == 'NRS_MSASPEC':
        for slit_number, slit in enumerate(output_model.slits, start=1):
            log.info('Working on slit {}'.format(slit_number))
            size = slit.data.size
            # Only slits that contain data can be corrected
            if size <= 0:
                log.warning("Slit has data size = {}, skipping "
                    "pathloss correction for this slitlet".format(size))
                continue
            # Get centering
            xcenter, ycenter = get_center(exp_type, slit)
            # The MSA aperture in the reference file is selected by the
            # number of open shutters of the slitlet
            nshutters = util.get_num_msa_open_shutters(slit.shutter_state)
            aperture = get_aperture_from_model(pathloss_model, nshutters)
            if aperture is None:
                log.warning("Cannot find matching pathloss model for slit "
                    "with {} shutters, skipping pathloss correction for this "
                    "slit".format(nshutters))
                continue
            pointsource, uniform, is_inside_slitlet = _pathloss_vectors_from_aperture(
                aperture, xcenter, ycenter)
            if not is_inside_slitlet:
                log.warning("Source is outside slit.  Skipping "
                    "pathloss correction for slit {}".format(slit_number))
                continue
            pathloss_2d = _interpolate_pathloss_2d(slit.wavelength,
                                                   is_pointsource(slit.source_type),
                                                   pointsource, uniform)
            _apply_pathloss_2d(slit, pathloss_2d)
        output_model.meta.cal_step.pathloss = 'COMPLETE'
    elif exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']:
        for slit in output_model.slits:
            log.info('Working on slit {}'.format(slit.name))
            # Get centering
            xcenter, ycenter = get_center(exp_type, slit)
            # Get the aperture from the reference file that matches the slit
            aperture = get_aperture_from_model(pathloss_model, slit.name)
            if aperture is None:
                log.warning("Cannot find matching pathloss model for aperture {} "
                    "skipping pathloss correction for this slit".format(slit.name))
                continue
            log.info("Using aperture {}".format(aperture.name))
            pointsource, uniform, is_inside_slit = _pathloss_vectors_from_aperture(
                aperture, xcenter, ycenter)
            if not is_inside_slit:
                log.warning("Source is outside slit.  Skipping "
                    "pathloss correction for slit {}".format(slit.name))
                continue
            pathloss_2d = _interpolate_pathloss_2d(slit.wavelength,
                                                   is_pointsource(slit.source_type),
                                                   pointsource, uniform)
            _apply_pathloss_2d(slit, pathloss_2d)
        output_model.meta.cal_step.pathloss = 'COMPLETE'
    elif exp_type == 'NRS_IFU':
        # IFU targets are always inside slit, so the inside-slit flag is unused
        # Get centering
        xcenter, ycenter = get_center(exp_type, None)
        # The first aperture of the reference model is used for the IFU
        aperture = pathloss_model.apertures[0]
        pointsource, uniform, _ = _pathloss_vectors_from_aperture(aperture, xcenter, ycenter)

        # Create the 2-d wavelength array, initialized with NaNs so that
        # pixels not covered by any slice stay NaN
        wavelength_array = np.zeros(input_model.shape, dtype=np.float32)
        wavelength_array.fill(np.nan)
        for slice_id in NIRSPEC_IFU_SLICES:
            slice_wcs = nirspec.nrs_wcs_set_input(input_model, slice_id)
            x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
            _, _, wavelength = slice_wcs(x, y)
            # Copy only the pixels for which this slice's WCS gives a
            # wavelength. Copying the whole bounding box would let the NaNs
            # outside this slice overwrite wavelengths written by a previous
            # slice wherever the bounding boxes overlap.
            valid = ~np.isnan(wavelength)
            wavelength_array[y[valid].astype(int), x[valid].astype(int)] = wavelength[valid]

        # Compute and apply the pathloss 2D correction
        pathloss_2d = _interpolate_pathloss_2d(wavelength_array,
                                               is_pointsource(input_model.meta.target.source_type),
                                               pointsource, uniform)
        _apply_pathloss_2d(output_model, pathloss_2d)

        # This might be useful to other steps
        output_model.wavelength = wavelength_array

        output_model.meta.cal_step.pathloss = 'COMPLETE'

    elif exp_type == 'NIS_SOSS':
        # NIRISS SOSS pathloss correction is basically a correction for the
        # flux from the 2nd and 3rd order dispersion that falls outside the
        # subarray aperture.  The correction depends on the pupil wheel
        # position and column number (or wavelength).  The simple option is
        # to do the correction by column number, then the only interpolation
        # needed is a 1-d interpolation into the pupil wheel position
        # dimension.

        # Omit correction if this is a TSO observation
        if input_model.meta.visit.tsovisit:
            log.warning("NIRISS SOSS TSO observations skip the pathloss step")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model

        pupil_wheel_position = input_model.meta.instrument.pupil_position
        if pupil_wheel_position is None:
            log.warning("Unable to get pupil wheel position from PWCPOS keyword "
                "for {}".format(input_model.meta.filename))
            log.warning("Pathloss correction skipped")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model

        subarray = input_model.meta.subarray.name
        # Get the aperture from the reference file that matches the subarray
        aperture = get_aperture_from_model(pathloss_model, subarray)
        if aperture is None:
            log.warning("Unable to get Aperture from reference file "
                "for subarray {}".format(subarray))
            log.warning("Pathloss correction skipped")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model
        log.info("Aperture {} selected from reference file".format(aperture.name))

        pathloss_array = aperture.pointsource_data[0]
        nrows, ncols = pathloss_array.shape
        _, data_ncols = input_model.data.shape
        # One correction value per data column; columns that cannot be
        # interpolated keep the initial value of 1.0
        correction = np.ones(data_ncols, dtype=np.float32)
        crpix1 = aperture.pointsource_wcs.crpix1
        crval1 = aperture.pointsource_wcs.crval1
        cdelt1 = aperture.pointsource_wcs.cdelt1
        # Fractional 0-based reference column of the pupil wheel position
        pupil_wheel_index = crpix1 + (pupil_wheel_position - crval1) / cdelt1 - 1

        # Linear interpolation needs reference columns ix and ix + 1
        if pupil_wheel_index < 0 or pupil_wheel_index > (ncols - 2):
            log.info("Pupil Wheel position outside reference file coverage")
            log.info("Setting pathloss correction to 1.0")
        else:
            ix = int(pupil_wheel_index)
            dx = pupil_wheel_index - ix
            crpix2 = aperture.pointsource_wcs.crpix2
            crval2 = aperture.pointsource_wcs.crval2
            cdelt2 = aperture.pointsource_wcs.cdelt2
            for column in range(data_ncols):
                # Reference-file row corresponding to this (1-based) data column
                column_1indexed = column + 1
                refrow_index = math.floor(crpix2 + (column_1indexed - crval2) / cdelt2 - 0.5)
                if 0 <= refrow_index <= nrows - 1:
                    correction[column] = (1.0 - dx) * pathloss_array[refrow_index, ix] + \
                                         dx * pathloss_array[refrow_index, ix + 1]

        # The per-column correction is repeated for every row of the data
        pathloss_2d = np.broadcast_to(correction, input_model.data.shape)
        _apply_pathloss_2d(output_model, pathloss_2d)

        output_model.meta.cal_step.pathloss = 'COMPLETE'

    return output_model
    def test_nirspec_ifu_masterbg_user(self):
        """
        Regression test of master background subtraction for NRS IFU when a
        user 1-D spectrum is provided.

        Three checks are made:
        1. the 1-D spectrum extracted from the background-subtracted result
           agrees (absolute mean difference within 5.0) with the truth
           spectrum of the background-free science data, over the wavelength
           range covered by both the user background and the science data;
        2. slice by slice, the background-subtracted image agrees (absolute
           mean difference within 2.0, after 5-sigma clipping of the science
           pixels) with the background-free science image;
        3. the saved result matches the truth file.
        """
        # input file has 2-D background image added to it
        input_file = self.get_data(*self.test_dir, 'prism_sci_bkg_cal.fits')

        # user-provided 1-D background was created from the 2-D background image
        user_background = self.get_data(*self.test_dir, 'prism_bkg_x1d.fits')

        result = MasterBackgroundStep.call(input_file,
                                           user_background=user_background,
                                           save_results=True)

        # Test 1 compare extracted spectra data with
        # no background added to extracted spectra from the output
        # from MasterBackground subtraction. First cube_build has to be run
        # on the data.
        result_s3d = CubeBuildStep.call(result)
        # run 1-D extract on results from MasterBackground step
        result_1d = Extract1dStep.call(result_s3d, subtract_background=False)

        # get the 1-D extracted spectrum from the science data in truth directory
        input_sci_1d_file = self.get_data(*self.ref_loc, 'prism_sci_extract1d.fits')
        sci_1d = datamodels.open(input_sci_1d_file)

        # read in the valid wavelengths of the user-1d
        user_background_model = datamodels.open(user_background)
        user_wave = user_background_model.spec[0].spec_table['wavelength']
        user_flux = user_background_model.spec[0].spec_table['net']
        user_wave_valid = np.where(user_flux > 0)
        min_user_wave = np.amin(user_wave[user_wave_valid])
        max_user_wave = np.amax(user_wave[user_wave_valid])
        user_background_model.close()

        # find the waverange covered by both user and science
        sci_spec_1d = sci_1d.spec[0].spec_table['net']
        sci_spec_wave = sci_1d.spec[0].spec_table['wavelength']

        result_spec_1d = result_1d.spec[0].spec_table['net']

        sci_wave_valid = np.where(sci_spec_1d > 0)
        min_wave = max(np.amin(sci_spec_wave[sci_wave_valid]), min_user_wave)
        max_wave = min(np.amax(sci_spec_wave[sci_wave_valid]), max_user_wave)

        sub_spec = sci_spec_1d - result_spec_1d
        valid = np.where(np.logical_and(sci_spec_wave > min_wave, sci_spec_wave < max_wave))
        sub_spec = sub_spec[valid]
        sub_spec = sub_spec[1:-2]  # endpoints are wacky

        # The 1-D products are no longer needed
        sci_1d.close()
        result_1d.close()
        result_s3d.close()

        mean_sub = np.absolute(np.nanmean(sub_spec))
        assert_allclose(mean_sub, 0, atol=5.0)

        # Test 2  compare the science data with no background
        # to the output from the masterBackground Subtraction step
        # background subtracted science image.
        input_sci_cal_file = self.get_data(*self.test_dir,
                                           'prism_sci_cal.fits')
        input_sci_model = datamodels.open(input_sci_cal_file)

        # We don't want the slice gaps to impact the statistic, so only
        # pixels with a finite wavelength are compared, one slice at a time
        # (loop over the 30 slices)
        for i in range(30):
            slice_wcs = nirspec.nrs_wcs_set_input(input_sci_model, i)
            x, y = grid_from_bounding_box(slice_wcs.bounding_box)
            _, _, lam = slice_wcs(x, y)
            valid = np.isfinite(lam)
            rows = y.astype(int)
            cols = x.astype(int)
            result_slice = result.data[rows, cols][valid]
            sci_slice = input_sci_model.data[rows, cols][valid]
            sub = result_slice - sci_slice

            # check for outliers in the science image
            sci_mean = np.nanmean(sci_slice)
            sci_std = np.nanstd(sci_slice)
            upper = sci_mean + sci_std*5.0
            lower = sci_mean - sci_std*5.0
            mask_clean = np.logical_and(sci_slice < upper, sci_slice > lower)

            sub_mean = np.absolute(np.nanmean(sub[mask_clean]))
            assert_allclose(sub_mean, 0, atol=2.0)

        input_sci_model.close()

        # Test 3 Compare background subtracted science data (results)
        #  to a truth file. This data is MultiSlit data
        result_file = result.meta.filename
        truth_file = self.get_data(*self.ref_loc,
                                   'prism_sci_bkg_masterbackgroundstep.fits')

        outputs = [(result_file, truth_file)]
        self.compare_outputs(outputs)
        result.close()
Exemple #24
0
def _msa_flagging_time_message(elapsed_seconds):
    """Return the run-time message for the msa_flagging validation test.

    The elapsed time is reported in seconds below one minute, in minutes
    below one hour, and in hours otherwise.
    """
    if elapsed_seconds < 60.0:
        value, unit = elapsed_seconds, " seconds to finish."
    elif elapsed_seconds / 60.0 < 60.0:
        value, unit = elapsed_seconds / 60.0, " minutes to finish."
    else:
        value, unit = elapsed_seconds / 60.0 / 60., " hours to finish."
    return "* MSA flagging validation test took " + repr(value) + unit


def run_msa_flagging_testing(input_file, msa_flagging_threshold=99.5, rate_obj=None,
                             stellarity=None, operability_ref=None, source_type=None,
                             save_figs=False, show_figs=True, debug=False):
    """
    This is the validation function for the msa flagging step.

    The failed-open shutters listed in the MSA operability reference file are
    written to an MSA metadata file, assign_wcs is run on the rate product with
    that metadata file, and for every resulting slit the percentage of its trace
    pixels that carry the MSA_FAILED_OPEN DQ flag (536870912) in the
    msa_flagging product is computed.

    :param input_file: string, fits file output from the msa_flagging step; a datamodel
                       may also be passed, in which case rate_obj must be given (and
                       save_figs cannot be used, since it needs a file name)
    :param msa_flagging_threshold: float, minimum flagged percentage required for all slits
                                   with at least 100 trace pixels
    :param rate_obj: object, the stage 1 pipeline output object (used when input_file is
                     not a string)
    :param stellarity: float, stellarity number from 0.0 to 1.0; if None it is 1.0 for
                       point sources and 0.0 otherwise
    :param operability_ref: string, msa failed open - operability - reference file; if None
                            jwst_nirspec_msaoper_0001.json is downloaded from CRDS into
                            the current directory
    :param source_type: string, options are point, extended, unknown; if None the
                        source_type_apt value of the datamodel is used
    :param save_figs: boolean
    :param show_figs: boolean
    :param debug: boolean
    :return:
        FINAL_TEST_RESULT: boolean, True if the flagged percentage is greater than or equal
                           to the threshold for every slit with at least 100 trace pixels;
                           the string 'skip' if the test could not be run
        result_msg: string, message with reason for passing, failing, or skipped
        log_msgs: list, diagnostic strings to be printed in log

    """
    # start the list of messages that will be added to the log file
    log_msgs = []

    # start the timer
    msa_flagging_test_start_time = time.time()

    # get the data model
    if isinstance(input_file, str):
        msaflag = datamodels.open(input_file)
    else:
        msaflag = input_file

    if debug:
        print('got MSA flagging datamodel!')

    # set up generals for all the plots
    font = {'weight': 'normal',
            'size': 12}
    matplotlib.rc('font', **font)

    # plot full image
    plt.figure(figsize=(9, 9))
    norm = ImageNormalize(msaflag.data, vmin=0., vmax=50., stretch=AsinhStretch())
    plt.imshow(msaflag.data, norm=norm, aspect=1.0, origin='lower', cmap='viridis')
    # Show and/or save figures
    detector = msaflag.meta.instrument.detector
    # datadir is only set when figures are saved; it also decides where the
    # failed-open MSA metafile is written (current directory when None)
    datadir = None
    if save_figs:
        file_basename = os.path.basename(input_file.replace("_msa_flagging.fits", ""))
        datadir = os.path.dirname(input_file)
        plt_name = "_".join((file_basename, "MSA_flagging_full_detector.png"))
        plt_name = os.path.join(datadir, plt_name)
        plt.savefig(plt_name)
        print('Figure saved as: ', plt_name)
    if show_figs:
        plt.show()
    plt.close()

    # read in DQ flags from MSA_flagging product
    # find all pixels that have been flagged by this step as MSA_FAILED_OPEN -> DQ array value = 536870912
    # https://jwst-pipeline.readthedocs.io/en/latest/jwst/references_general/references_general.html?highlight=536870912#data-quality-flags
    dq_flag = 536870912
    msaflag_1d = msaflag.dq.flatten()
    index_opens = np.squeeze(np.asarray(np.where(msaflag_1d & dq_flag)))
    if debug:
        print("DQ array at 167, 1918: ", msaflag.dq[167, 1918])
        print("Index where Failed Open shutters exist: ", np.shape(index_opens), index_opens)

    # execute script that creates an MSA metafile for the failed open shutters
    # read operability reference file: unless a file is given, it is downloaded
    # from the CRDS server (CRDS_PATH is not used)
    crds_path = "https://jwst-crds.stsci.edu/unchecked_get/references/jwst/"
    op_ref_file = "jwst_nirspec_msaoper_0001.json"

    if operability_ref is None:
        ref_file = os.path.join(crds_path, op_ref_file)
        urllib.request.urlretrieve(ref_file, op_ref_file)
    else:
        op_ref_file = operability_ref

    if "http" not in op_ref_file and not os.path.isfile(op_ref_file):
        result_msg = "Skipping msa_flagging test because the operability reference file does not exist: " + \
                     op_ref_file
        print(result_msg)
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    if debug:
        print("Using this operability reference file: ", op_ref_file)

    with open(op_ref_file) as f:
        msaoper_dict = json.load(f)
    msaoper = msaoper_dict["msaoper"]

    # find the failed open shutters
    failedopens = [(c["Q"], c["x"], c["y"]) for c in msaoper if c["state"] == 'open']
    if debug:
        print("Failed Open shutters: ", failedopens)

    # unpack the list of tuples into separate lists for MSA quadrant, row, column locations
    quads, allrows, allcols = zip(*failedopens)

    # stellarity -- internal lamps are uniform illumination, so set to 0
    # ** if considering a point source, need to change this to 1, or actual value if known
    if source_type is None:
        srctyapt = msaflag.meta.target.source_type_apt
    else:
        srctyapt = source_type.upper()
    if stellarity is None:
        stellarity = 1.0 if "POINT" in srctyapt else 0.0
    else:
        stellarity = float(stellarity)

    # create MSA metafile with F/O shutters
    if datadir is not None:
        fometafile = os.path.join(datadir, 'fopens_metafile_msa.fits')
    else:
        fometafile = 'fopens_metafile_msa.fits'
    if not os.path.isfile(fometafile):
        pattnum = msaflag.meta.dither.position_number
        create_metafile_fopens(fometafile, allcols, allrows, quads, stellarity, failedopens,
                               pattnum, save_fig=save_figs, show_fig=show_figs, debug=debug)

    # run assign_wcs on the science exposure using F/O metafile
    # change MSA metafile name in header to match the F/O metafile name
    if isinstance(input_file, str):
        rate_file = input_file.replace("msa_flagging", "rate")
        if not os.path.isfile(rate_file):
            # if a _rate.fits file does not exist try the usual name in the directory of
            # the input file (datadir cannot be used here: it is None unless save_figs)
            input_dir = os.path.dirname(input_file)
            rate_file = os.path.join(input_dir, 'final_output_caldet1_'+detector+'.fits')
            if not os.path.isfile(rate_file):
                result_msg = "Skipping msa_flagging test because no rate fits file was found in directory: " + \
                             input_dir
                print(result_msg)
                log_msgs.append(result_msg)
                result = 'skip'
                return result, result_msg, log_msgs
        if debug:
            print("Will run assign_wcs with new Failed Open fits file on this file: ", rate_file)
        rate_mdl = datamodels.ImageModel(rate_file)
    else:
        rate_mdl = rate_obj

    if debug:
        print("MSA metadata file in initial rate file: ", rate_mdl.meta.instrument.msa_metadata_file)

    rate_mdl.meta.instrument.msa_metadata_file = fometafile
    if debug:
        print("New MSA metadata file in rate file: ", rate_mdl.meta.instrument.msa_metadata_file)

    # force the exp_type of this new model to MSA, even if IFU so that nrs_wcs_set_input pipeline function works
    if "ifu" in msaflag.meta.exposure.type.lower():
        rate_mdl.meta.exposure.type = 'NRS_MSASPEC'

    # run assign_wcs; use +/-0.45 for the y-limits because the default is too big (0.6 including buffer)
    stp = AssignWcsStep()
    awcs_fo = stp.call(rate_mdl, slit_y_low=-0.45, slit_y_high=0.45)

    # get the slits from the F/O processing run
    slits_list = awcs_fo.meta.wcs.get_transform('gwa', 'slit_frame').slits

    # prepare arrays to hold info needed for validation test
    allsizes = np.zeros(len(slits_list))
    allchecks = np.zeros(len(slits_list))

    # loop over the slits and compare pixel bounds with the flagged pixels from the original product
    for i, slit in enumerate(slits_list):
        try:
            name = slit.name
        except AttributeError:
            name = i
        print("\nWorking with slit/slice: ", name)
        if "IFU" not in msaflag.meta.exposure.type.upper():
            print("Slit min and max in y direction: ", slit.ymin, slit.ymax)
        # get the WCS object for this particular slit
        wcs_slice = nirspec.nrs_wcs_set_input(awcs_fo, name)
        # get the bounding box for the 2D subwindow, round to nearest integer, and convert to integer
        bboxint = np.rint(wcs_slice.bounding_box).astype(int)
        print("bounding box rounded to next integer: ", bboxint)
        i1 = bboxint[0, 0]
        i2 = bboxint[0, 1]
        i3 = bboxint[1, 0]
        i4 = bboxint[1, 1]
        # make array of pixel locations within bounding box (arrays are indexed [x, y])
        x, y = np.mgrid[i1:i2, i3:i4]
        index_1d = np.ravel_multi_index([[y], [x]], (2048, 2048))
        # get the slity WCS parameter to find which pixels are located in the actual spectrum
        # (slity is NaN for pixels outside the nominal shutter length)
        det2slit = wcs_slice.get_transform('detector', 'slit_frame')
        slitx, slity, _ = det2slit(x, y)
        print("Max value in slity array (ignoring NANs): ", np.nanmax(slity))
        index_trace = np.squeeze(index_1d)[~np.isnan(slity)]
        n_overlap = np.sum(np.isin(index_opens, index_trace))
        overlap_percentage = round(n_overlap/index_trace.size*100., 1)
        if debug:
            print("Size of index_trace= ", index_trace.size)
            print("Size of index_opens=", index_opens.size)
            print("Sum of values found in index_opens and index_trace=", n_overlap)
        msg = 'percentage of F/O trace that was flagged: ' + repr(overlap_percentage)
        print(msg)
        log_msgs.append(msg)
        allchecks[i] = overlap_percentage
        allsizes[i] = index_trace.size

        # show 2D cutouts, with flagged pixels overlaid; slity from above is
        # reused (it is indexed [x, y], hence the transposes below)
        # extract & display the F/O 2d subwindows from the msa_flagging sci image
        plt.figure(figsize=(19, 19))
        data_cutout = msaflag.data[i3:i4, i1:i2]
        subwin = data_cutout.copy()
        # set all pixels outside of the nominal shutter length to 0, inside to 1
        outside_slit = np.isnan(slity.T)
        subwin[outside_slit] = 0
        subwin[~outside_slit] = 1
        # find the pixels flagged by the msaflagopen step; set them to 1 and everything else to 0 for ease of display
        subwin_dq = msaflag.dq[i3:i4, i1:i2].copy()
        mask = (subwin_dq & dq_flag) != 0
        subwin_dq[mask] = 1
        subwin_dq[~mask] = 0
        # plot the F/O traces
        vmax = np.max(data_cutout)
        norm = ImageNormalize(data_cutout, vmin=0., vmax=vmax, stretch=AsinhStretch())
        plt.imshow(data_cutout, norm=norm, aspect=10.0, origin='lower', cmap='viridis',
                   label='MSA flagging data')
        plt.imshow(subwin, aspect=20.0, origin='lower', cmap='Reds', alpha=0.3, label='Calculated F/O')
        # overplot the flagged pixels in translucent grayscale
        plt.imshow(subwin_dq, aspect=20.0, origin='lower', cmap='gray', alpha=0.3, label='Pipeline F/O')
        if save_figs:
            t = (file_basename, "FailedOpen_detector", detector, "slit", repr(name) + ".png")
            plt_name = "_".join(t)
            plt_name = os.path.join(datadir, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        plt.close()

    # validation: overlap should be >= msa_flagging_threshold percent for all slits with at least 100 pixels
    FINAL_TEST_RESULT = False
    if not isinstance(msa_flagging_threshold, float):
        msa_flagging_threshold = float(msa_flagging_threshold)
    if (allchecks[allsizes >= 100] >= msa_flagging_threshold).all():
        FINAL_TEST_RESULT = True
    else:
        print("\n * One or more traces show msa_flagging match < ", repr(msa_flagging_threshold))
        print("   See results above per trace. \n")
    if FINAL_TEST_RESULT:
        result_msg = "\n *** Final result for msa_flagging test will be reported as PASSED *** \n"
    else:
        result_msg = "\n *** Final result for msa_flagging test will be reported as FAILED *** \n"
    print(result_msg)
    log_msgs.append(result_msg)

    # end the timer; the message is a single string (not a tuple) so that it
    # prints cleanly and the log list only holds strings
    msa_flagging_test_tot_time = _msa_flagging_time_message(time.time() - msa_flagging_test_start_time)
    print(msa_flagging_test_tot_time)
    log_msgs.append(msa_flagging_test_tot_time)

    # close the datamodel
    msaflag.close()
    rate_mdl.close()

    return FINAL_TEST_RESULT, result_msg, log_msgs
def do_correction(input_model, pathloss_model):
    """
    Short Summary
    -------------
    Execute all tasks for Path Loss Correction

    The 1-D point-source and uniform-source pathloss vectors are computed
    from the matching aperture of ``pathloss_model``, interpolated onto the
    wavelength grid of each slit (or of the whole IFU image) and attached
    to the model.  The science arrays themselves are not divided by the
    pathloss in this function.

    Parameters
    ----------
    input_model: data model object
        science data to be corrected.  It is modified in place: the
        pathloss attributes are attached to it (or to its slits) and
        ``meta.cal_step.pathloss`` is set to 'COMPLETE' for the handled
        exposure types (NRS_MSASPEC, NRS_FIXEDSLIT, NRS_BRIGHTOBJ, NRS_IFU).

    pathloss_model: pathloss model object
        pathloss correction data

    Returns
    -------
    output_model: data model object
        A copy of ``input_model`` with the pathloss extensions added

    """

    def _attach_slit_pathloss(slit, aperture, xcenter, ycenter):
        """Attach the 1-D and 2-D pathloss arrays of ``aperture`` to ``slit``.

        Returns True when the source lies inside the slit and the arrays
        were attached, False when it is outside (``slit`` is left untouched).
        """
        (wavelength_pointsource,
         pathloss_pointsource_vector,
         is_inside) = calculate_pathloss_vector(aperture.pointsource_data,
                                                aperture.pointsource_wcs,
                                                xcenter, ycenter)
        (wavelength_uniformsource,
         pathloss_uniform_vector,
         _) = calculate_pathloss_vector(aperture.uniform_data,
                                        aperture.uniform_wcs,
                                        xcenter, ycenter)
        if not is_inside:
            return False
        # Wavelengths in the reference file are in meters, need them to be
        # in microns.  Scale out of place so the arrays returned by
        # calculate_pathloss_vector are never mutated.
        wavelength_pointsource = wavelength_pointsource * 1.0e6
        wavelength_uniformsource = wavelength_uniformsource * 1.0e6
        slit.pathloss_pointsource = pathloss_pointsource_vector
        slit.wavelength_pointsource = wavelength_pointsource
        slit.pathloss_uniformsource = pathloss_uniform_vector
        slit.wavelength_uniformsource = wavelength_uniformsource
        # Create the 2-d pathloss arrays on the slit's own wavelength grid
        wavelength_array = slit.wavelength
        pathloss_pointsource_2d = interpolate_onto_grid(wavelength_array,
                                                        wavelength_pointsource,
                                                        pathloss_pointsource_vector)
        pathloss_uniformsource_2d = interpolate_onto_grid(wavelength_array,
                                                          wavelength_uniformsource,
                                                          pathloss_uniform_vector)
        slit.pathloss_pointsource2d = pathloss_pointsource_2d
        slit.pathloss_uniformsource2d = pathloss_uniformsource_2d
        return True

    exp_type = input_model.meta.exposure.type
    log.info(exp_type)
    if exp_type == 'NRS_MSASPEC':
        for slit_number, slit in enumerate(input_model.slits, start=1):
            log.info('Working on slit %d', slit_number)
            size = slit.data.size
            # Only slits that actually contain data can be corrected
            if size <= 0:
                log.warning("Slit has data size = {}, skipping pathloss "
                            "correction for this slitlet".format(size))
                continue
            # Get centering
            xcenter, ycenter = getCenter(exp_type, slit)
            # Get the aperture from the reference file that matches the
            # number of open shutters of this slitlet
            nshutters = util.get_num_msa_open_shutters(slit.shutter_state)
            aperture = getApertureFromModel(pathloss_model, nshutters)
            if aperture is None:
                log.warning("Cannot find matching pathloss model for slit with "
                            "%d shutters, skipping pathloss correction for this "
                            "slitlet", nshutters)
            elif not _attach_slit_pathloss(slit, aperture, xcenter, ycenter):
                log.warning("Source is outside slitlet, skipping pathloss "
                            "correction for this slitlet")
        input_model.meta.cal_step.pathloss = 'COMPLETE'
    elif exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']:
        for slit in input_model.slits:
            log.info(slit.name)
            # Get centering
            xcenter, ycenter = getCenter(exp_type, slit)
            # Get the aperture from the reference file that matches the slit
            aperture = getApertureFromModel(pathloss_model, slit.name)
            if aperture is None:
                log.warning("Cannot find matching pathloss model for aperture "
                            "%s, skipping pathloss correction for this slit",
                            slit.name)
                continue
            log.info("Using aperture %s", aperture.name)
            if not _attach_slit_pathloss(slit, aperture, xcenter, ycenter):
                log.warning("Source is outside slit, skipping pathloss "
                            "correction for this slit")
        input_model.meta.cal_step.pathloss = 'COMPLETE'
    elif exp_type == 'NRS_IFU':
        # IFU targets are always inside slit
        xcenter, ycenter = getCenter(exp_type, None)
        # Calculate the 1-d wavelength and pathloss vectors
        # for the source position
        aperture = pathloss_model.apertures[0]
        (wavelength_pointsource,
         pathloss_pointsource_vector,
         _) = calculate_pathloss_vector(aperture.pointsource_data,
                                        aperture.pointsource_wcs,
                                        xcenter, ycenter)
        (wavelength_uniformsource,
         pathloss_uniform_vector,
         _) = calculate_pathloss_vector(aperture.uniform_data,
                                        aperture.uniform_wcs,
                                        xcenter, ycenter)
        # Wavelengths in the reference file are in meters, need them to be
        # in microns
        wavelength_pointsource = wavelength_pointsource * 1.0e6
        wavelength_uniformsource = wavelength_uniformsource * 1.0e6
        input_model.wavelength_pointsource = wavelength_pointsource
        input_model.pathloss_pointsource = pathloss_pointsource_vector
        input_model.wavelength_uniformsource = wavelength_uniformsource
        input_model.pathloss_uniformsource = pathloss_uniform_vector

        # Build the 2-d wavelength image; pixels not covered by any slice
        # stay NaN.
        wavelength_array = np.full(input_model.shape, np.nan, dtype=np.float32)
        for slice_id in NIRSPEC_IFU_SLICES:
            slice_wcs = nirspec.nrs_wcs_set_input(input_model, slice_id)
            x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
            _, _, wavelength = slice_wcs(x, y)
            # Copy only the valid (non-NaN) wavelengths.  Writing the whole
            # bounding-box rectangle would overwrite, with NaN, the valid
            # wavelengths of a previous slice wherever the boxes overlap.
            valid = ~np.isnan(wavelength)
            wavelength_array[y[valid].astype(int),
                             x[valid].astype(int)] = wavelength[valid]
        pathloss_pointsource_2d = interpolate_onto_grid(wavelength_array,
                                                        wavelength_pointsource,
                                                        pathloss_pointsource_vector)
        pathloss_uniformsource_2d = interpolate_onto_grid(wavelength_array,
                                                          wavelength_uniformsource,
                                                          pathloss_uniform_vector)
        input_model.pathloss_pointsource2d = pathloss_pointsource_2d
        input_model.pathloss_uniformsource2d = pathloss_uniformsource_2d

        # This might be useful to other steps
        input_model.wavelength = wavelength_array
        input_model.meta.cal_step.pathloss = 'COMPLETE'

    return input_model.copy()
Exemple #26
0
def FindFootPrintNIRSPEC(self, input, this_channel):
    """
    Short Summary
    -------------
    Find the spatial and spectral footprint of the NIRSpec IFU slices.

    For each slice (0-29) the slice WCS is evaluated on every pixel of the
    slice domain, and the min and max of the two spatial outputs and of the
    wavelength are recorded.  The footprint is the min/max over all slices.
    The values are those at the pixel centers.

    Parameters
    ----------
    self : object
        Not used.
    input : input model
        Model providing ``meta.wcs``; it is also passed to
        ``nirspec.spectral_order_wrange_from_model`` to obtain the
        wavelength range used to build each slice WCS.
    this_channel :
        Channel working with (not used by this function).

    Returns
    -------
    a_min, a_max, b_min, b_max, lambda_min, lambda_max : float
        Min/max of the first spatial coordinate (``v2 * 60``), of the
        second spatial coordinate (``v3 * 60``) and of the wavelength.
        NaN values are ignored.
    """
    start_slice = 0
    end_slice = 29
    nslices = end_slice - start_slice + 1

    # Per-slice extrema: entries 2k and 2k+1 hold the min and max of slice k
    a_slice = np.zeros(nslices * 2)
    b_slice = np.zeros(nslices * 2)
    lambda_slice = np.zeros(nslices * 2)

    # Only the wavelength range is needed; the spectral order is unused
    _, wrange = nirspec.spectral_order_wrange_from_model(input)
    for k, slice_number in enumerate(range(start_slice, end_slice + 1)):
        slice_wcs = nirspec.nrs_wcs_set_input(input.meta.wcs, 0, slice_number, wrange)
        bounds = _domain_to_bounds(slice_wcs.domain)
        y, x = np.mgrid[bounds[1][0]:bounds[1][1], bounds[0][0]:bounds[0][1]]
        v2, v3, lam = slice_wcs(x, y)

        # Both spatial coordinates are scaled by 60 (presumably degrees to
        # arcminutes -- TODO confirm the units returned by the WCS)
        coord1 = v2 * 60.0
        coord2 = v3 * 60.0

        a_slice[2 * k:2 * k + 2] = np.nanmin(coord1), np.nanmax(coord1)
        b_slice[2 * k:2 * k + 2] = np.nanmin(coord2), np.nanmax(coord2)
        lambda_slice[2 * k:2 * k + 2] = np.nanmin(lam), np.nanmax(lam)

    # NaN-aware reductions: the builtin min/max return an order-dependent
    # result (e.g. NaN if it is the first element) when a slice has NaN
    # extrema.
    a_min = np.nanmin(a_slice)
    a_max = np.nanmax(a_slice)
    b_min = np.nanmin(b_slice)
    b_max = np.nanmax(b_slice)
    lambda_min = np.nanmin(lambda_slice)
    lambda_max = np.nanmax(lambda_slice)

    return a_min, a_max, b_min, b_max, lambda_min, lambda_max
Exemple #27
0
def do_correction(input_model, pathloss_model):
    """
    Short Summary
    -------------
    Execute all tasks for Path Loss Correction

    The correction is applied to a copy of ``input_model``.  The 2-D
    pathloss is divided out of ``data`` and ``err`` (and, squared, out of
    ``var_poisson`` and ``var_rnoise``) and attached as ``pathloss``.
    For NIRSpec (MSA, fixed slit / bright object, IFU) the point-source or
    uniform-source pathloss is selected with ``is_pointsource``.  For
    NIRISS SOSS a per-column correction is interpolated in pupil wheel
    position; TSO exposures, or missing pupil wheel / aperture information,
    mark the step as SKIPPED.

    Parameters
    ----------
    input_model : data model object
        science data to be corrected

    pathloss_model : pathloss model object
        pathloss correction data

    Returns
    -------
    output_model : data model object
        Science data with pathloss extensions added

    """

    def _apply_pathloss(target, pathloss_2d):
        """Divide ``pathloss_2d`` out of ``target``'s arrays and attach it."""
        target.data /= pathloss_2d
        target.err /= pathloss_2d
        # Variances scale with the square of the correction
        target.var_poisson /= pathloss_2d**2
        target.var_rnoise /= pathloss_2d**2
        target.pathloss = pathloss_2d

    def _correct_slit(slit, aperture, xcenter, ycenter):
        """Compute and apply the pathloss of ``aperture`` to one slit.

        Returns False (slit left untouched) when the source is outside it.
        """
        (wavelength_pointsource,
         pathloss_pointsource_vector,
         is_inside) = calculate_pathloss_vector(aperture.pointsource_data,
                                                aperture.pointsource_wcs,
                                                xcenter, ycenter)
        (wavelength_uniformsource,
         pathloss_uniform_vector,
         _) = calculate_pathloss_vector(aperture.uniform_data,
                                        aperture.uniform_wcs,
                                        xcenter, ycenter)
        if not is_inside:
            return False
        # Wavelengths in the reference file are in meters, need them to be
        # in microns
        wavelength_pointsource = wavelength_pointsource * 1.0e6
        wavelength_uniformsource = wavelength_uniformsource * 1.0e6
        wavelength_array = slit.wavelength
        if is_pointsource(slit.source_type):
            pathloss_2d = interpolate_onto_grid(wavelength_array,
                                                wavelength_pointsource,
                                                pathloss_pointsource_vector)
        else:
            pathloss_2d = interpolate_onto_grid(wavelength_array,
                                                wavelength_uniformsource,
                                                pathloss_uniform_vector)
        _apply_pathloss(slit, pathloss_2d)
        return True

    exp_type = input_model.meta.exposure.type
    log.info(exp_type)
    output_model = input_model.copy()
    if exp_type == 'NRS_MSASPEC':
        for slit_number, slit in enumerate(output_model.slits, start=1):
            log.info('Working on slit {}'.format(slit_number))
            size = slit.data.size
            # Only slits that actually contain data can be corrected
            if size <= 0:
                log.warning("Slit has data size = {}, skipping "
                    "pathloss correction for this slitlet".format(size))
                continue
            # Get centering
            xcenter, ycenter = get_center(exp_type, slit)
            # Get the aperture from the reference file that matches the
            # number of open shutters of this slit
            nshutters = util.get_num_msa_open_shutters(slit.shutter_state)
            aperture = get_aperture_from_model(pathloss_model, nshutters)
            if aperture is None:
                log.warning("Cannot find matching pathloss model for slit "
                    "with {} shutters, skipping pathloss correction for this "
                    "slit".format(nshutters))
            elif not _correct_slit(slit, aperture, xcenter, ycenter):
                log.warning("Source is outside slit.  Skipping "
                    "pathloss correction for slit {}".format(slit_number))
        output_model.meta.cal_step.pathloss = 'COMPLETE'
    elif exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']:
        for slit in output_model.slits:
            log.info(slit.name)
            # Get centering
            xcenter, ycenter = get_center(exp_type, slit)
            # Get the aperture from the reference file that matches the slit
            aperture = get_aperture_from_model(pathloss_model, slit.name)
            if aperture is None:
                log.warning("Cannot find matching pathloss model for aperture {} "
                    "skipping pathloss correction for this slit".format(slit.name))
                continue
            log.info("Using aperture {}".format(aperture.name))
            if not _correct_slit(slit, aperture, xcenter, ycenter):
                log.warning("Source is outside slit.  Skipping "
                    "pathloss correction for slit {}".format(slit.name))
        output_model.meta.cal_step.pathloss = 'COMPLETE'
    elif exp_type == 'NRS_IFU':
        # IFU targets are always inside slit
        xcenter, ycenter = get_center(exp_type, None)
        # Calculate the 1-d wavelength and pathloss vectors
        # for the source position
        aperture = pathloss_model.apertures[0]
        (wavelength_pointsource,
         pathloss_pointsource_vector,
         _) = calculate_pathloss_vector(aperture.pointsource_data,
                                        aperture.pointsource_wcs,
                                        xcenter, ycenter)
        (wavelength_uniformsource,
         pathloss_uniform_vector,
         _) = calculate_pathloss_vector(aperture.uniform_data,
                                        aperture.uniform_wcs,
                                        xcenter, ycenter)
        # Wavelengths in the reference file are in meters, need them to be
        # in microns
        wavelength_pointsource = wavelength_pointsource * 1.0e6
        wavelength_uniformsource = wavelength_uniformsource * 1.0e6

        # Build the 2-d wavelength image; pixels not covered by any slice
        # stay NaN.
        wavelength_array = np.full(input_model.shape, np.nan, dtype=np.float32)
        for slice_id in NIRSPEC_IFU_SLICES:
            slice_wcs = nirspec.nrs_wcs_set_input(input_model, slice_id)
            x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
            _, _, wavelength = slice_wcs(x, y)
            # Copy only the valid (non-NaN) wavelengths.  Writing the whole
            # bounding-box rectangle would overwrite, with NaN, the valid
            # wavelengths of a previous slice wherever the boxes overlap.
            valid = ~np.isnan(wavelength)
            wavelength_array[y[valid].astype(int),
                             x[valid].astype(int)] = wavelength[valid]

        # Compute the pathloss 2D correction
        if is_pointsource(input_model.meta.target.source_type):
            pathloss_2d = interpolate_onto_grid(
                wavelength_array,
                wavelength_pointsource,
                pathloss_pointsource_vector)
        else:
            pathloss_2d = interpolate_onto_grid(
                wavelength_array,
                wavelength_uniformsource,
                pathloss_uniform_vector)
        _apply_pathloss(output_model, pathloss_2d)

        # This might be useful to other steps
        output_model.wavelength = wavelength_array

        output_model.meta.cal_step.pathloss = 'COMPLETE'

    elif exp_type == 'NIS_SOSS':
        # NIRISS SOSS pathloss correction is basically a correction for the
        # flux from the 2nd and 3rd order dispersion that falls outside the
        # subarray aperture.  The correction depends on the pupil wheel
        # position and column number (or wavelength).  The simple option is
        # to do the correction by column number, then the only interpolation
        # needed is a 1-d interpolation into the pupil wheel position
        # dimension.

        # Omit correction if this is a TSO observation
        if input_model.meta.visit.tsovisit:
            log.warning("NIRISS SOSS TSO observations skip the pathloss step")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model

        pupil_wheel_position = input_model.meta.instrument.pupil_position
        if pupil_wheel_position is None:
            log.warning("Unable to get pupil wheel position from PWCPOS keyword "
                "for {}".format(input_model.meta.filename))
            log.warning("Pathloss correction skipped")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model

        subarray = input_model.meta.subarray.name
        # Get the aperture from the reference file that matches the subarray
        aperture = get_aperture_from_model(pathloss_model, subarray)
        if aperture is None:
            log.warning("Unable to get Aperture from reference file "
                "for subarray {}".format(subarray))
            log.warning("Pathloss correction skipped")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model
        log.info("Aperture {} selected from reference file".format(aperture.name))

        # Reference array: axis 0 is indexed by detector column (through the
        # CRPIX2/CRVAL2/CDELT2 WCS terms), axis 1 by pupil wheel position
        # (through CRPIX1/CRVAL1/CDELT1).
        pathloss_array = aperture.pointsource_data[0]
        n_ref_rows, n_wheel_positions = pathloss_array.shape
        _, data_ncols = input_model.data.shape
        correction = np.ones(data_ncols, dtype=np.float32)
        crpix1 = aperture.pointsource_wcs.crpix1
        crval1 = aperture.pointsource_wcs.crval1
        cdelt1 = aperture.pointsource_wcs.cdelt1
        # Fractional 0-based index along the pupil wheel axis
        pupil_wheel_index = crpix1 + (pupil_wheel_position - crval1) / cdelt1 - 1

        # ix + 1 must still be a valid index for the linear interpolation
        if pupil_wheel_index < 0 or pupil_wheel_index > (n_wheel_positions - 2):
            log.info("Pupil Wheel position outside reference file coverage")
            log.info("Setting pathloss correction to 1.0")
        else:
            ix = int(pupil_wheel_index)
            dx = pupil_wheel_index - ix
            crpix2 = aperture.pointsource_wcs.crpix2
            crval2 = aperture.pointsource_wcs.crval2
            cdelt2 = aperture.pointsource_wcs.cdelt2
            for column in range(data_ncols):
                column_1indexed = column + 1
                # Nearest 0-based reference row for this detector column
                refrow_index = math.floor(crpix2 + (column_1indexed - crval2) / cdelt2 - 0.5)
                if refrow_index < 0 or refrow_index > (n_ref_rows - 1):
                    correction[column] = 1.0
                else:
                    # Linear interpolation in pupil wheel position
                    correction[column] = (1.0 - dx) * pathloss_array[refrow_index, ix] + \
                                         dx * pathloss_array[refrow_index, ix + 1]

        # The same per-column correction applies to every row
        pathloss_2d = np.broadcast_to(correction, input_model.data.shape)
        _apply_pathloss(output_model, pathloss_2d)

        output_model.meta.cal_step.pathloss = 'COMPLETE'

    return output_model
Exemple #28
0
def test_functional_ifu_grating():
    """Compare Nirspec instrument model with IDT model for IFU grating.

    Builds the IFU WCS for a G395H/F290LP exposure (slice 0), evaluates
    each intermediate frame of the chain at five slit positions and
    wavelengths, and compares the results with the columns of the ESA
    table ``ifu_grating_functional_ESA_v1_20180619.txt``.
    """

    # setup test
    model_file = 'ifu_grating_functional_ESA_v1_20180619.txt'
    hdul = create_nirspec_ifu_file(grating='G395H',
                                   filter='F290LP',
                                   gwa_xtil=0.35986012,
                                   gwa_ytil=0.13448857)
    im = datamodels.ImageModel(hdul)
    refs = create_reference_files(im)
    pipeline = nirspec.create_pipeline(im, refs, slit_y_range=[-0.55, 0.55])
    w = wcs.WCS(pipeline)
    im.meta.wcs = w
    slit_wcs = nirspec.nrs_wcs_set_input(im, 0)  # use slice 0
    ins_file = get_file_path(model_file)
    ins_tab = table.Table.read(ins_file, format='ascii')
    # Five test points: fixed x, y spanning the slit, and wavelengths
    # 2.9-5 micron expressed in meters (factor 10**-6)
    slitx = [0] * 5
    slity = [-.5, -.25, 0, .25, .5]
    lam = np.array([2.9, 3.39, 3.88, 4.37, 5]) * 10**-6
    order, wrange = nirspec.get_spectral_order_wrange(im,
                                                      refs['wavelengthrange'])
    im.meta.wcsinfo.sporder = order
    im.meta.wcsinfo.waverange_start = wrange[0]
    im.meta.wcsinfo.waverange_end = wrange[1]

    # Slit to MSA entrance
    # This includes the Slicer transform and the IFUFORE transform
    # ('ymaspos' (sic) is the name of the MSA y-position column used
    # throughout this test)
    slit2msa = slit_wcs.get_transform('slit_frame', 'msa_frame')
    msax, msay, _ = slit2msa(slitx, slity, lam)
    assert_allclose(slitx, ins_tab['xslitpos'])
    assert_allclose(slity, ins_tab['yslitpos'])
    assert_allclose(msax + 0.0073, ins_tab['xmsapos'],
                    rtol=1e-2)  # expected offset
    assert_allclose(msay + 0.0085, ins_tab['ymaspos'],
                    rtol=1e-2)  # expected offset

    # Slicer
    slit2slicer = slit_wcs.get_transform('slit_frame', 'slicer')
    x_slicer, y_slicer, _ = slit2slicer(slitx, slity, lam)

    # MSA exit
    # Applies the IFUPOST transform to coordinates at the Slicer
    with datamodels.IFUPostModel(refs['ifupost']) as ifupost:
        ifupost_transform = nirspec._create_ifupost_transform(ifupost.slice_0)
    x_msa_exit, y_msa_exit = ifupost_transform(x_slicer, y_slicer, lam)
    assert_allclose(x_msa_exit, ins_tab['xmsapos'])
    assert_allclose(y_msa_exit, ins_tab['ymaspos'])

    # Computations are done using the exact form of the equations in the reports
    # Part I of the Forward IFU-POST transform - the linear transform
    xc_out = 0.0487158154447
    yc_out = 0.00856211956976
    xc_in = 0.000355277216
    yc_in = -3.0089012e-05
    theta = np.deg2rad(-0.129043957046)
    factor_x = 0.100989874454
    factor_y = 0.100405184145

    # Slicer coordinates
    xS = 0.000399999989895
    yS = -0.00600000005215

    # Shift, rotate by theta and scale the slicer point (x, y are reused
    # below as the inputs of the polynomial)
    x = xc_out + factor_x * (+cos(theta) * (xS - xc_in) + sin(theta) *
                             (yS - yc_in))
    y = yc_out + factor_y * (-sin(theta) * (xS - xc_in) + cos(theta) *
                             (yS - yc_in))

    # Forward IFU-POST II part - non-linear transform
    # A degree-5 2-D polynomial whose coefficients depend linearly on the
    # wavelength: coeff = y_forw + lam * y_forw_dist.  'lam' is temporarily
    # the scalar wavelength of the first test point.
    lam = 2.9e-6
    coef_names = [
        f'c{x}_{y}' for x in range(6) for y in range(6) if x + y <= 5
    ]
    y_forw = [
        -82.3492267824, 29234.6982762, -540260.780853, 771881.305018,
        -2563462.26848, 29914272.1164, 4513.04082605, -2212869.44311,
        32875633.0303, -29923698.5288, 27293902.5636, -39820.4434726,
        62431493.9962, -667197265.033, 297253538.182, -1838860.86305,
        -777169857.2, 4514693865.7, 42790637.764, 3596423850.94, -260274017.448
    ]
    y_forw_dist = [
        188531839.97, -43453434864.0, 70807756765.8, -308272809909.0,
        159768473071.0, 9712633344590.0, -11762923852.9, 3545938873190.0,
        -4198643655420.0, 12545642983100.0, -11707051591600.0, 173091230285.0,
        -108534069056000.0, 82893348097600.0, -124708740989000.0,
        2774389757990.0, 1476779720300000.0, -545358301961000.0,
        -93101557994100.0, -7536890639430000.0, 646310545048000.0
    ]
    y_coeff = {}
    for i, coef in enumerate(coef_names):
        y_coeff[coef] = y_forw[i] + lam * y_forw_dist[i]
    poly2d = astmodels.Polynomial2D(5, **y_coeff)
    ifupost_y = poly2d(x, y)
    # The hand-computed value must match both the ESA table and the pipeline
    # transform for the first test point
    assert_allclose(ifupost_y, ins_tab['ymaspos'][0])
    assert_allclose(ifupost_y, y_msa_exit[0])

    # reset 'lam'
    lam = np.array([2.9, 3.39, 3.88, 4.37, 5]) * 10**-6

    # Coordinates at Collimator exit
    # Applies the inverse of the collimator reference model
    # (col.model.inverse) to the coordinates at the MSA exit
    with datamodels.open(refs['collimator']) as col:
        colx, coly = col.model.inverse(x_msa_exit, y_msa_exit)
    assert_allclose(colx, ins_tab['xcoll'])
    assert_allclose(coly, ins_tab['ycoll'])

    # After applying directional cosines
    dircos = trmodels.Unitless2DirCos()
    xcolDircosi, ycolDircosi, z = dircos(colx, coly)
    assert_allclose(xcolDircosi, ins_tab['xcolDirCosi'])
    assert_allclose(ycolDircosi, ins_tab['ycolDirCosi'])

    # Slit to GWA entrance
    # applies the Collimator forward, Unitless to Directional and 3D Rotation to MSA exit coordinates
    with datamodels.DisperserModel(refs['disperser']) as disp:
        disperser = nirspec.correct_tilt(disp, im.meta.instrument.gwa_xtilt,
                                         im.meta.instrument.gwa_ytilt)
    collimator2gwa = nirspec.collimator_to_gwa(refs, disperser)
    x_gwa_in, y_gwa_in, z_gwa_in = collimator2gwa(x_msa_exit, y_msa_exit)
    assert_allclose(x_gwa_in, ins_tab['xdispIn'])

    # Slit to GWA out
    # Runs slit--> slicer --> msa_exit --> collimator --> dircos --> rotation --> angle_from_grating equation
    slit2gwa = slit_wcs.get_transform('slit_frame', 'gwa')
    x_gwa_out, y_gwa_out, z_gwa_out = slit2gwa(slitx, slity, lam)
    assert_allclose(x_gwa_out, ins_tab['xdispLaw'])
    assert_allclose(y_gwa_out, ins_tab['ydispLaw'])

    # CAMERA entrance (assuming direction is from sky to detector)
    # Undo the GWA rotation built from the tilt-corrected disperser angles,
    # then convert directional cosines back to unitless coordinates
    angles = [
        disperser['theta_x'], disperser['theta_y'], disperser['theta_z'],
        disperser['tilt_y']
    ]
    rotation = trmodels.Rotation3DToGWA(angles,
                                        axes_order="xyzy",
                                        name='rotation')
    dircos2unitless = trmodels.DirCos2Unitless()
    gwa2cam = rotation.inverse | dircos2unitless
    x_camera_entrance, y_camera_entrance = gwa2cam(x_gwa_out, y_gwa_out,
                                                   z_gwa_out)
    assert_allclose(x_camera_entrance, ins_tab['xcamCosi'])
    assert_allclose(y_camera_entrance, ins_tab['ycamCosi'])

    # at FPA
    with datamodels.CameraModel(refs['camera']) as camera:
        x_fpa, y_fpa = camera.model.inverse(x_camera_entrance,
                                            y_camera_entrance)
    assert_allclose(x_fpa, ins_tab['xfpapos'])
    assert_allclose(y_fpa, ins_tab['yfpapos'])

    # at SCA
    slit2sca = slit_wcs.get_transform('slit_frame', 'sca')
    x_sca_nrs1, y_sca_nrs1 = slit2sca(slitx, slity, lam)

    # At NRS2
    # The first three test points are compared on NRS1 and the last two on
    # NRS2.  The +1 presumably converts the pipeline's 0-based pixel
    # coordinates to the 1-based convention of the ESA table.
    with datamodels.FPAModel(refs['fpa']) as fpa:
        x_sca_nrs2, y_sca_nrs2 = fpa.nrs2_model.inverse(x_fpa, y_fpa)
    assert_allclose(x_sca_nrs1[:3] + 1, ins_tab['i'][:3])
    assert_allclose(y_sca_nrs1[:3] + 1, ins_tab['j'][:3])
    assert_allclose(x_sca_nrs2[3:] + 1, ins_tab['i'][3:])
    assert_allclose(y_sca_nrs2[3:] + 1, ins_tab['j'][3:])

    # at oteip
    # Goes through slicer, ifufore, and fore transforms
    slit2oteip = slit_wcs.get_transform('slit_frame', 'oteip')
    x_oteip, y_oteip, _ = slit2oteip(slitx, slity, lam)
    assert_allclose(x_oteip, ins_tab['xOTEIP'])
    assert_allclose(y_oteip, ins_tab['yOTEIP'])

    # at v2, v3 [in arcsec]
    # Divided by 3600 before comparing (arcsec -> degrees, assuming the
    # ESA table is in degrees)
    slit2v23 = slit_wcs.get_transform('slit_frame', 'v2v3')
    v2, v3, _ = slit2v23(slitx, slity, lam)
    v2 /= 3600
    v3 /= 3600
    assert_allclose(v2, ins_tab['xV2V3'])
    assert_allclose(v3, ins_tab['yV2V3'])
Exemple #29
0
# ESA slit names (SLITID keyword of the ESA intermediary files) mapped to the
# pipeline slit names.
_FS_ESA_TO_PIPELINE_SLIT = {
    'SLIT_A_1600': 'S1600A1',
    'SLIT_A_200_1': 'S200A1',
    'SLIT_A_200_2': 'S200A2',
    'SLIT_A_400': 'S400A1',
    'SLIT_B_200': 'S200B1',
}

# Raw data files whose ESA intermediary products are not in the regular
# directory tree, mapped to the NID that auxfunc.get_esafile needs to find them.
_FS_SPECIAL_CUTOUT_NIDS = {
    "NRSSMOS-MOD-G1H-02-5344031756_1_491_SE_2015-12-10T03h25m56.fits": "30055",
    "NRSSMOS-MOD-G1H-02-5344031756_1_492_SE_2015-12-10T03h25m56.fits": "30055",
    "NRSSMOS-MOD-G2M-01-5344191938_1_491_SE_2015-12-10T19h29m26.fits": "30205",
    "NRSSMOS-MOD-G3H-02-5344120942_1_491_SE_2015-12-10T12h18m25.fits": "30133",
    "NRSSMOS-MOD-G3H-02-5344120942_1_492_SE_2015-12-10T12h18m25.fits": "30133",
}


def _fs_log(log_msgs, msg):
    """Print ``msg`` and append it to ``log_msgs``."""
    print(msg)
    log_msgs.append(msg)


def _fs_stat_is_nan(value):
    """Return True if the scalar statistic ``value`` is NaN.

    ``value is np.nan`` only matches the ``np.nan`` object itself and misses
    NaNs produced by numpy reductions (e.g. ``np.float64('nan')``), so
    ``np.isnan`` is used instead. Values that ``np.isnan`` cannot handle are
    treated as "not NaN", exactly as the identity check treated them.
    """
    try:
        return bool(np.isnan(value))
    except (TypeError, ValueError):
        return False


def _fs_read_esa_arrays(esafile, esahdulist, det_suffix, log_msgs):
    """Read the ESA trace arrays of one detector from an ESA intermediary file.

    Args:
        esafile: str, path of the ESA intermediary file
        esahdulist: the open HDU list of ``esafile`` (its LAMBDA header gives the trace WCS)
        det_suffix: str, "1" for NRS1 or "2" for NRS2 (suffix of the extension names)
        log_msgs: list, messages are appended here

    Returns:
        dict with keys "flux", "wave", "slity", "msax", "msay", "pyw" and, only
        when both V2V3 extensions are present, "v2v3x" and "v2v3y".

    Raises:
        KeyError: propagated from the FITS extension lookups when one of the
            required (non-V2V3) extensions is missing.
    """
    esa = {
        # DATA<n> is not compared, but reading it keeps a missing DATA<n>
        # extension a reason to skip this detector.
        "flux": fits.getdata(esafile, "DATA" + det_suffix),
        "wave": fits.getdata(esafile, "LAMBDA" + det_suffix),
        "slity": fits.getdata(esafile, "SLITY" + det_suffix),
        "msax": fits.getdata(esafile, "MSAX" + det_suffix),
        "msay": fits.getdata(esafile, "MSAY" + det_suffix),
        "pyw": wcs.WCS(esahdulist["LAMBDA" + det_suffix].header),
    }
    try:
        v2v3x = fits.getdata(esafile, "V2V3X" + det_suffix)
        v2v3y = fits.getdata(esafile, "V2V3Y" + det_suffix)
    except KeyError:
        _fs_log(log_msgs, "Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions.")
    else:
        esa["v2v3x"] = v2v3x
        esa["v2v3y"] = v2v3y
    return esa


def _fs_check_quantity(tested_quantity, esa_slity, esa_values, pipe_values,
                       threshold_diff, log_msgs, slit_results,
                       convert_arcsec=False):
    """Compare one pipeline quantity with its ESA counterpart and record the verdict.

    Args:
        tested_quantity: str, label of the quantity; also the key used in ``slit_results``
        esa_slity: ESA slit-y array, forwarded to auxfunc.get_reldiffarr_and_stats
        esa_values: ESA values of the quantity
        pipe_values: pipeline values of the quantity at the same pixels
        threshold_diff: float, threshold the median relative difference is tested against
        log_msgs: list, messages are appended here
        slit_results: dict, ``slit_results[tested_quantity]`` receives the test result
        convert_arcsec: bool, if True and the first comparison statistic is
            positive, the pipeline values are divided by 3600 (arcsec -> degrees)
            and the comparison is redone

    Returns:
        The 4-tuple from auxfunc.get_reldiffarr_and_stats:
        (relative-difference image, not-NaN relative differences, statistics, printable stats).
    """
    diff_data = auxfunc.get_reldiffarr_and_stats(
        threshold_diff, esa_slity, esa_values, pipe_values, tested_quantity)
    _, _, stats, print_stats = diff_data
    # NOTE(review): heuristic kept from the original code -- a positive stats[0]
    # is taken to mean the pipeline values are still in arcsec; confirm what
    # stats[0] represents in auxfunc.get_reldiffarr_and_stats.
    if convert_arcsec and stats[0] > 0.0:
        print("\nConverting pipeline results to degrees to compare with ESA")
        diff_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, esa_slity, esa_values, pipe_values / 3600., tested_quantity)
        _, _, stats, print_stats = diff_data
    log_msgs.extend(print_stats)
    # stats[1] is the median relative difference checked against the threshold
    test_result = auxfunc.does_median_pass_tes(stats[1], threshold_diff)
    _fs_log(log_msgs, "\n * Result of the test for " + tested_quantity + ":  " + test_result + "\n")
    slit_results[tested_quantity] = test_result
    return diff_data


def _fs_plot_quantity(diff_data, title, xlabel, unable_msg, plt_name,
                      show_figs, save_figs, log_msgs, bins=15, plt_origin=None):
    """Plot the relative-difference image and histogram of one tested quantity.

    Args:
        diff_data: 4-tuple returned by _fs_check_quantity
        title: str, complete plot title
        xlabel: str, x-axis label of the histogram
        unable_msg: str, message logged instead of plotting when the median is NaN
        plt_name: str, file name used for the figure
        show_figs, save_figs: booleans forwarded to auxfunc.plt_two_2Dimgandhist
        log_msgs: list, messages are appended here
        bins: histogram binning; if None the plotting function calculates it
        plt_origin: forwarded to auxfunc.plt_two_2Dimgandhist
    """
    rel_diff_img, notnan_rel_diff, stats, _ = diff_data
    if _fs_stat_is_nan(stats[1]):
        _fs_log(log_msgs, unable_msg)
        return
    info_img = [title, "x (pixels)", "y (pixels)"]
    info_hist = [xlabel, "N", bins, stats]
    auxfunc.plt_two_2Dimgandhist(rel_diff_img,
                                 notnan_rel_diff,
                                 info_img,
                                 info_hist,
                                 plt_name=plt_name,
                                 plt_origin=plt_origin,
                                 show_figs=show_figs,
                                 save_figs=save_figs)


def _fs_final_result(total_test_result, log_msgs):
    """Log every per-slit test result and return the overall verdict.

    The overall result is "PASSED" only when at least one test was recorded
    and none of them is "FAILED"; otherwise it is "FAILED".
    """
    verdicts = []
    for slit_name, slit_tests in total_test_result.items():
        for quantity, result in slit_tests.items():
            verdict = "FAILED" if result == "FAILED" else "PASSED"
            _fs_log(log_msgs, "\n * The test of " + quantity + " for slit " + slit_name + " " + verdict + ".")
            verdicts.append(verdict)

    final_result = "PASSED" if verdicts and "FAILED" not in verdicts else "FAILED"
    if final_result == "PASSED":
        msg = "\n *** Final result for assign_wcs test will be reported as PASSED *** \n"
    else:
        msg = "\n *** Final result for assign_wcs test will be reported as FAILED *** \n"
    _fs_log(log_msgs, msg)
    return final_result


def compare_wcs(infile_name,
                esa_files_path=None,
                show_figs=True,
                save_figs=False,
                threshold_diff=1.0e-7,
                debug=False):
    """
    Compare the pipeline WCS of a fixed-slit assign_wcs product with the ESA intermediary files.

    For every open slit, the pipeline wavelength, slit-y, MSA x/y and (when the
    ESA file provides them) V2/V3 are evaluated at the ESA trace pixels and
    compared with the ESA values through relative differences.

    Args:
        infile_name: str, name of the output fits file from the assign_wcs step (with full path)
        esa_files_path: str, full path of where to find all ESA intermediary products to make comparisons for the tests
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots or not
        threshold_diff: float, threshold difference between pipeline output and ESA file
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - plots, if told to save and/or show them.
        - FINAL_TEST_RESULT: str, "PASSED" if every recorded test passed, "FAILED" if
          any test failed or no test was run, "skip" if the needed ESA data is missing
        - log_msgs: list, the messages produced during the comparison
    """
    log_msgs = []

    # get detector, grating, filter and lamp info from the assign_wcs file header
    det = fits.getval(infile_name, "DETECTOR", 0)
    _fs_log(log_msgs, 'infile_name=' + infile_name)
    lamp = fits.getval(infile_name, "LAMP", 0)
    grat = fits.getval(infile_name, "GRATING", 0)
    filt = fits.getval(infile_name, "FILTER", 0)
    _fs_log(log_msgs,
            "from assign_wcs file  -->     Detector: " + det + "   Grating: " + grat
            + "   Filter: " + filt + "   Lamp: " + lamp)

    # {slit name: {tested quantity: test result}}, used to decide the final verdict
    total_test_result = OrderedDict()

    # the context manager makes sure the datamodel is closed on every exit path
    with datamodels.ImageModel(infile_name) as img:
        # slits that are open and projected on the detector
        open_slits = img.meta.wcs.get_transform('gwa', 'slit_frame').slits
        for opslit in open_slits:
            pipeslit = opslit.name
            _fs_log(log_msgs, "\nWorking with slit: " + pipeslit)

            # Get the ESA trace
            _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file()
            specifics = [pipeslit]

            # some ESA data is not in the regular directory tree and needs an NID
            nid = _FS_SPECIAL_CUTOUT_NIDS.get(raw_data_root_file)
            if nid is not None:
                _fs_log(log_msgs, "Using NID = " + nid)

            esafile, esafile_log_msgs = auxfunc.get_esafile(esa_files_path,
                                                            raw_data_root_file,
                                                            "FS",
                                                            specifics,
                                                            nid=nid)
            log_msgs.extend(esafile_log_msgs)

            # skip the test if the esafile was not found
            if esafile == "ESA file not found":
                _fs_log(log_msgs, " * compare_wcs_fs.py is exiting because the corresponding ESA file was not found.")
                _fs_log(log_msgs, "   -> The WCS test is now set to skip and no plots will be generated. ")
                return "skip", log_msgs

            # Open the trace in the esafile
            _fs_log(log_msgs, "Using this ESA file: \n" + esafile)
            with fits.open(esafile) as esahdulist:
                print("* ESA file contents ")
                esahdulist.info()
                esa_slit_id = _FS_ESA_TO_PIPELINE_SLIT[esahdulist[0].header['SLITID']]
                if pipeslit == esa_slit_id:
                    _fs_log(log_msgs, "\n -> Same slit found for pipeline and ESA data: " + pipeslit + "\n")
                else:
                    _fs_log(log_msgs,
                            "\n -> Mismatch of slits for pipeline and ESA data: "
                            + pipeslit + ", " + esa_slit_id + "\n")

                # read the ESA arrays of the detector in use
                if det == "NRS1":
                    try:
                        esa = _fs_read_esa_arrays(esafile, esahdulist, "1", log_msgs)
                    except KeyError:
                        _fs_log(log_msgs, "This file does not contain data for detector NRS1. Skipping test for this slit.")
                        continue
                if det == "NRS2":
                    try:
                        esa = _fs_read_esa_arrays(esafile, esahdulist, "2", log_msgs)
                    except KeyError:
                        _fs_log(log_msgs, "\n * compare_wcs_fs.py is exiting because there are no extensions that match detector NRS2 in the ESA file.")
                        _fs_log(log_msgs, "   -> The WCS test is now set to skip and no plots will be generated. \n")
                        return "skip", log_msgs
            skipv2v3test = "v2v3x" not in esa

            # get the WCS object for this particular slit
            wcs_slit = nirspec.nrs_wcs_set_input(img, pipeslit)

            if debug:
                # evaluate the detector -> slit_frame transform at one pixel and
                # show how many inputs/outputs the transform has
                det2slit = wcs_slit.get_transform('detector', 'slit_frame')
                slitx, slity, lam = det2slit(700, 1080)
                print("slitx: ", slitx)
                print("slity: ", slity)
                print("lambda: ", lam)
                print('Number on inputs: ', det2slit.n_inputs)
                print('Number on outputs: ', det2slit.n_outputs)

            # Create x, y indices using the Trace WCS
            pipey, pipex = np.mgrid[:esa["wave"].shape[0], :esa["wave"].shape[1]]
            esax, esay = esa["pyw"].all_pix2world(pipex, pipey, 0)

            if det == "NRS2":
                esax = 2049 - esax
                esay = 2049 - esay
                _fs_log(log_msgs, "Flipped ESA data for detector NRS2 comparison with pipeline.")

            # check if subarray is not FULL FRAME
            subarray = fits.getval(infile_name, "SUBARRAY", 0)
            if "FULL" not in subarray:
                # subtract xstart and ystart (1-based) to get subarray coordinates
                # instead of full-frame ones; the bounding box is then not applied
                xstart, ystart = img.meta.subarray.xstart, img.meta.subarray.ystart
                esay = esay - (ystart - 1)
                esax = esax - (xstart - 1)
                bounding_box = False
            else:
                bounding_box = True

            # The pipeline WCS is evaluated at (esax - 1, esay - 1), presumably
            # turning the 1-based ESA coordinates into 0-based ones.
            slit_results = OrderedDict()
            total_test_result[pipeslit] = slit_results

            # Wavelength: pipeline values are scaled by 1e-6 before comparing
            # (presumably microns -> meters to match the ESA units; confirm)
            _, _, pwave = wcs_slit(esax - 1, esay - 1, with_bounding_box=bounding_box)
            pwave *= 10**-6
            wave_data = _fs_check_quantity("Wavelength Difference", esa["slity"], esa["wave"], pwave,
                                           threshold_diff, log_msgs, slit_results)

            # Slit-y
            det2slit = wcs_slit.get_transform('detector', 'slit_frame')
            _, slity, _ = det2slit(esax - 1, esay - 1, with_bounding_box=bounding_box)
            slity_data = _fs_check_quantity("Slit-Y Difference", esa["slity"], esa["slity"], slity,
                                            threshold_diff, log_msgs, slit_results)

            # MSA x and y
            detector2msa = wcs_slit.get_transform("detector", "msa_frame")
            pmsax, pmsay, _ = detector2msa(esax - 1, esay - 1, with_bounding_box=bounding_box)
            msax_data = _fs_check_quantity("MSA_X Difference", esa["slity"], esa["msax"], pmsax,
                                           threshold_diff, log_msgs, slit_results)
            msay_data = _fs_check_quantity("MSA_Y Difference", esa["slity"], esa["msay"], pmsay,
                                           threshold_diff, log_msgs, slit_results)

            # V2 and V3 (pipeline may be in arcsec, ESA in degrees)
            if not skipv2v3test:
                detector2v2v3 = wcs_slit.get_transform("detector", "v2v3")
                pv2, pv3, _ = detector2v2v3(esax - 1, esay - 1, with_bounding_box=bounding_box)
                v2_data = _fs_check_quantity("V2 difference", esa["slity"], esa["v2v3x"], pv2,
                                             threshold_diff, log_msgs, slit_results,
                                             convert_arcsec=True)
                v3_data = _fs_check_quantity("V3 difference", esa["slity"], esa["v2v3y"], pv3,
                                             threshold_diff, log_msgs, slit_results,
                                             convert_arcsec=True)

            # PLOTS
            if show_figs or save_figs:
                basenameinfile_name = os.path.basename(infile_name)
                main_title = filt + "   " + grat + "   SLIT=" + pipeslit + "\n"
                # (diff data, title line, histogram x label, message if NaN, file suffix)
                plot_specs = [
                    (wave_data,
                     r"Relative wavelength difference = $\Delta \lambda$",
                     r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{ESA}) / \lambda_{ESA}$",
                     "Unable to create plot of relative wavelength difference.",
                     "_rel_wave_diffs.pdf"),
                    (slity_data,
                     r"Relative slit position = $\Delta$slit_y",
                     r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{ESA}$)/slit_y$_{ESA}$",
                     "Unable to create plot of relative slit position.",
                     "_rel_slitY_diffs.pdf"),
                    (msax_data,
                     r"Relative MSA-x Difference = $\Delta$MSA_x",
                     r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{ESA}$)/MSA_x$_{ESA}$",
                     "Unable to create plot of relative MSA-x difference.",
                     "_rel_MSAx_diffs.pdf"),
                    (msay_data,
                     r"Relative MSA-y Difference = $\Delta$MSA_y",
                     r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{ESA}$)/MSA_y$_{ESA}$",
                     "Unable to create plot of relative MSA-y difference.",
                     "_rel_MSAy_diffs.pdf"),
                ]
                if not skipv2v3test:
                    plot_specs += [
                        (v2_data,
                         r"Relative V2 Difference = $\Delta$V2",
                         r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{ESA}$)/V2$_{ESA}$",
                         "Unable to create plot of relative V2 difference.",
                         "_rel_V2_diffs.pdf"),
                        (v3_data,
                         r"Relative V3 Difference = $\Delta$V3",
                         r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{ESA}$)/V3$_{ESA}$",
                         "Unable to create plot of relative V3 difference.",
                         "_rel_V3_diffs.pdf"),
                    ]
                for diff_data, subtitle, xlabel, unable_msg, suffix in plot_specs:
                    plt_name = infile_name.replace(basenameinfile_name,
                                                   pipeslit + "_" + det + suffix)
                    _fs_plot_quantity(diff_data, main_title + subtitle + "\n", xlabel,
                                      unable_msg, plt_name, show_figs, save_figs, log_msgs)
            else:
                _fs_log(log_msgs, "NO plots were made because show_figs and save_figs were both set to False. \n")

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    FINAL_TEST_RESULT = _fs_final_result(total_test_result, log_msgs)
    return FINAL_TEST_RESULT, log_msgs
Example #30
0
def compare_wcs(infile_name,
                truth_file=None,
                esa_files_path=None,
                show_figs=True,
                save_figs=False,
                threshold_diff=1.0e-7,
                raw_data_root_file=None,
                output_directory=None,
                debug=False):
    """
    This function does the WCS comparison from the world coordinates calculated using the pipeline
    data model with the truth files or the ESA intermediary files (to create new truth files).

    For every open slit of the pipeline file, the wavelength, slit-y, MSA x/y and (when available)
    V2/V3 values are compared to the reference values. Each quantity is tested with
    auxfunc.does_median_pass_tes against threshold_diff.

    Args:
        infile_name: str or datamodel, name of the output fits file from the assign_wcs step (with
            full path), or an already-open model
        truth_file: str, full path to the 'truth' (or benchmark) file to compare to
        esa_files_path: str or None, path to file all the ESA intermediary files; ignored when
            truth_file is given
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots or not
        threshold_diff: float, threshold difference between pipeline output and truth file
        raw_data_root_file: None or string, name of the raw file that produced the _uncal.fits file for caldetector1
        output_directory: None or string, directory where plots are saved when infile_name is not a
            path; if None the current working directory is used
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - plots, if told to save and/or show them.
        - FINAL_TEST_RESULT: str, "PASSED" only if at least one test ran and none FAILED;
          "FAILED" otherwise; "skip" when the needed ESA data is not available
        - log_msgs: list, all print statements captured in this variable

    Raises:
        ValueError: if neither truth_file nor esa_files_path is given.
    """
    if truth_file is None and esa_files_path is None:
        raise ValueError("Either truth_file or esa_files_path must be provided.")

    log_msgs = []
    truth_model = None
    img = None
    opened_img = False
    try:
        if truth_file is not None:
            # HDUList.info() prints the summary itself (and returns None), so it is not wrapped in print()
            with fits.open(truth_file) as truth_hdul:
                print("Information from the 'truth' (or comparison) file ")
                truth_hdul.info()
            # get the open slits in the truth file
            truth_img = datamodels.ImageModel(truth_file)
            try:
                open_slits_truth = truth_img.meta.wcs.get_transform('gwa', 'slit_frame').slits
            finally:
                truth_img.close()
            # open the datamodel used to build the per-slit truth WCS
            truth_model = datamodels.MultiSlitModel(truth_file)
            # a truth file takes precedence over the ESA data
            esa_files_path = None
            print("Comparing to ST 'truth' file.")
        else:
            print("Comparing to ESA data")

        # get the datamodel from the assign_wcs output file (or use the model passed in)
        if isinstance(infile_name, str):
            img = datamodels.ImageModel(infile_name)
            opened_img = True
            msg = 'infile_name=' + infile_name
            print(msg)
            log_msgs.append(msg)
            basenameinfile_name = os.path.basename(infile_name)
        else:
            img = infile_name
            basenameinfile_name = ""

        # get grating and filter info from the assign_wcs file header
        det = img.meta.instrument.detector
        lamp = img.meta.instrument.lamp_state
        grat = img.meta.instrument.grating
        filt = img.meta.instrument.filter
        msg = "from assign_wcs file  -->     Detector: " + det + "   Grating: " + grat + "   Filter: " + \
              filt + "   Lamp: " + lamp
        print(msg)
        log_msgs.append(msg)
        print("GWA_XTILT: {}".format(img.meta.instrument.gwa_xtilt))
        print("GWA_YTILT: {}".format(img.meta.instrument.gwa_ytilt))
        print("GWA_TTILT: {}".format(img.meta.instrument.gwa_tilt))

        # per-slit test outcomes: {slit name: {tested quantity: test result}}
        total_test_result = OrderedDict()

        # mapping the ESA slit names to pipeline names
        map_slit_names = {
            'SLIT_A_1600': 'S1600A1',
            'SLIT_A_200_1': 'S200A1',
            'SLIT_A_200_2': 'S200A2',
            'SLIT_A_400': 'S400A1',
            'SLIT_B_200': 'S200B1',
        }
        # ESA data that is not in the regular directory tree: raw file name -> NID to use
        special_cutout_nids = {
            "NRSSMOS-MOD-G1H-02-5344031756_1_491_SE_2015-12-10T03h25m56.fits": "30055",
            "NRSSMOS-MOD-G1H-02-5344031756_1_492_SE_2015-12-10T03h25m56.fits": "30055",
            "NRSSMOS-MOD-G2M-01-5344191938_1_491_SE_2015-12-10T19h29m26.fits": "30205",
            "NRSSMOS-MOD-G3H-02-5344120942_1_491_SE_2015-12-10T12h18m25.fits": "30133",
            "NRSSMOS-MOD-G3H-02-5344120942_1_492_SE_2015-12-10T12h18m25.fits": "30133",
        }

        # To get the open and projected on the detector slits of the pipeline processed file
        open_slits = img.meta.wcs.get_transform('gwa', 'slit_frame').slits
        for opslit in open_slits:
            pipeslit = opslit.name
            msg = "\nWorking with slit: " + pipeslit
            print(msg)
            log_msgs.append(msg)

            if esa_files_path is not None:
                # Get the ESA trace
                if raw_data_root_file is None:
                    _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file(
                        infile_name)

                nid = special_cutout_nids.get(raw_data_root_file)
                if nid is not None:
                    msg = "Using NID = " + nid
                    print(msg)
                    log_msgs.append(msg)

                esafile, esafile_log_msgs = auxfunc.get_esafile(esa_files_path,
                                                                raw_data_root_file,
                                                                "FS",
                                                                [pipeslit],
                                                                nid=nid)
                log_msgs.extend(esafile_log_msgs)

                # skip the test if the esafile was not found
                if esafile == "ESA file not found":
                    msg1 = " * compare_wcs_fs.py is exiting because the corresponding ESA file was not found."
                    msg2 = "   -> The WCS test is now set to skip and no plots will be generated. "
                    print(msg1)
                    print(msg2)
                    log_msgs.append(msg1)
                    log_msgs.append(msg2)
                    return "skip", log_msgs

                # Open the trace in the esafile
                msg = "Using this ESA file: \n" + esafile
                print(msg)
                log_msgs.append(msg)
                with fits.open(esafile) as esahdulist:
                    print("* ESA file contents ")
                    esahdulist.info()
                    esa_slit_id = map_slit_names[esahdulist[0].header['SLITID']]
                    # first check is esa_slit == to pipe_slit?
                    if pipeslit == esa_slit_id:
                        msg = "\n -> Same slit found for pipeline and ESA data: " + pipeslit + "\n"
                    else:
                        msg = "\n -> Missmatch of slits for pipeline and ESA data: " + pipeslit + ", " + \
                              esa_slit_id + "\n"
                    print(msg)
                    log_msgs.append(msg)

                    # Assign variables according to detector
                    if det == "NRS1":
                        try:
                            esa_truth = _fs_read_esa_extensions(esafile, esahdulist, "1", log_msgs)
                        except KeyError:
                            msg = "This file does not contain data for detector NRS1. Skipping test for this slit."
                            print(msg)
                            log_msgs.append(msg)
                            continue
                    elif det == "NRS2":
                        try:
                            esa_truth = _fs_read_esa_extensions(esafile, esahdulist, "2", log_msgs)
                        except KeyError:
                            msg1 = "\n * compare_wcs_fs.py is exiting because there are no extensions that match " \
                                   "detector NRS2 in the ESA file."
                            msg2 = "   -> The WCS test is now set to skip and no plots will be generated. \n"
                            print(msg1)
                            print(msg2)
                            log_msgs.append(msg1)
                            log_msgs.append(msg2)
                            return "skip", log_msgs
                    truth_wave, truth_slity, truth_msax, truth_msay, truth_v2, truth_v3, pyw = esa_truth
                    skipv2v3test = truth_v2 is None

                # Create x, y indices using the trace WCS from ESA
                pipey, pipex = np.mgrid[:truth_wave.shape[0], :truth_wave.shape[1]]
                esax, esay = pyw.all_pix2world(pipex, pipey, 0)

                if det == "NRS2":
                    esax = 2049 - esax
                    esay = 2049 - esay
                    msg = "Flipped ESA data for detector NRS2 comparison with pipeline."
                    print(msg)
                    log_msgs.append(msg)

                # remove 1 to start from 0
                truth_x, truth_y = esax - 1, esay - 1

            else:
                # the current open slit must also be open in the truth file
                if not any(truth_open_slit.name == pipeslit for truth_open_slit in open_slits_truth):
                    msg1 = "\n * Script compare_wcs_fs.py is exiting because open slit " + pipeslit + \
                           " is not open in truth file."
                    msg2 = "   -> The WCS test is now set to FAILED and no plots will be generated. \n"
                    print(msg1)
                    print(msg2)
                    log_msgs.append(msg1)
                    log_msgs.append(msg2)
                    return "FAILED", log_msgs

                # get the WCS object for this particular truth slit and evaluate it on its own grid
                slit_wcs = nirspec.nrs_wcs_set_input(truth_model, pipeslit)
                truth_x, truth_y = wcstools.grid_from_bounding_box(
                    slit_wcs.bounding_box, step=(1, 1), center=True)
                _, _, truth_wave = slit_wcs(truth_x, truth_y)  # wave is in microns
                truth_wave *= 10**-6  # scaled by 1e-6, the same factor applied to the pipeline wavelengths
                _, truth_slity, _ = slit_wcs.get_transform('detector', 'slit_frame')(truth_x, truth_y)
                truth_msax, truth_msay, _ = slit_wcs.get_transform("detector", "msa_frame")(truth_x, truth_y)
                truth_v2, truth_v3, _ = slit_wcs.get_transform("detector", "v2v3")(truth_x, truth_y)
                skipv2v3test = False

            # get the WCS object for this particular slit (print(wcs_slit) lists all transforms)
            wcs_slit = nirspec.nrs_wcs_set_input(img, pipeslit)

            # In different observing modes the WCS may have different coordinate frames.
            available_frames = wcs_slit.available_frames
            print("Avalable frames: ", available_frames)

            if debug:
                # To get specific pixel values use following syntax:
                det2slit = wcs_slit.get_transform('detector', 'slit_frame')
                slitx, slity, lam = det2slit(700, 1080)
                print("slitx: ", slitx)
                print("slity: ", slity)
                print("lambda: ", lam)
                # The number of inputs and outputs in each frame can vary.
                print('Number on inputs: ', det2slit.n_inputs)
                print('Number on outputs: ', det2slit.n_outputs)

            # The WCS bounding box is only honored when "FULL" is found in the subarray metadata.
            # NOTE(review): this tests membership in the metadata node, not its .name; confirm intent.
            bounding_box = "FULL" in img.meta.subarray

            # (reldiff data, title, histogram x-label, message if not plottable, plot file suffix)
            plot_specs = []

            # Wavelength
            _, _, pwave = wcs_slit(truth_x, truth_y, with_bounding_box=bounding_box)
            pwave *= 10**-6  # same 1e-6 scaling as the truth wavelengths
            tested_quantity = "Wavelength Difference"
            reldiff = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, truth_slity, truth_wave, pwave, tested_quantity)
            _fs_evaluate(reldiff, tested_quantity, pipeslit, threshold_diff, total_test_result, log_msgs)
            plot_specs.append((
                reldiff,
                r"Relative wavelength difference = $\Delta \lambda$",
                r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{truth}) / \lambda_{truth}$",
                "Unable to create plot of relative wavelength difference.",
                "_rel_wave_diffs.png"))

            # Slit-y
            _, slity, _ = wcs_slit.get_transform('detector', 'slit_frame')(
                truth_x, truth_y, with_bounding_box=bounding_box)
            tested_quantity = "Slit-Y Difference"
            reldiff = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, truth_slity, truth_slity, slity, tested_quantity, absolute=False)
            _fs_evaluate(reldiff, tested_quantity, pipeslit, threshold_diff, total_test_result, log_msgs)
            plot_specs.append((
                reldiff,
                r"Relative slit position = $\Delta$slit_y",
                r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{truth}$)/slit_y$_{truth}$",
                "Unable to create plot of relative slit position.",
                "_rel_slitY_diffs.png"))

            # MSA x and y
            pmsax, pmsay, _ = wcs_slit.get_transform("detector", "msa_frame")(
                truth_x, truth_y, with_bounding_box=bounding_box)
            tested_quantity = "MSA_X Difference"
            reldiff = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, truth_slity, truth_msax, pmsax, tested_quantity)
            _fs_evaluate(reldiff, tested_quantity, pipeslit, threshold_diff, total_test_result, log_msgs)
            plot_specs.append((
                reldiff,
                r"Relative MSA-x Difference = $\Delta$MSA_x",
                r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{truth}$)/MSA_x$_{truth}$",
                "Unable to create plot of relative MSA-x difference.",
                "_rel_MSAx_diffs.png"))

            tested_quantity = "MSA_Y Difference"
            reldiff = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, truth_slity, truth_msay, pmsay, tested_quantity)
            _fs_evaluate(reldiff, tested_quantity, pipeslit, threshold_diff, total_test_result, log_msgs)
            plot_specs.append((
                reldiff,
                r"Relative MSA-y Difference = $\Delta$MSA_y",
                r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{truth}$)/MSA_y$_{truth}$",
                "Unable to create plot of relative MSA-y difference.",
                "_rel_MSAy_diffs.png"))

            # V2 and V3
            if not skipv2v3test and 'v2v3' in available_frames:
                pv2, pv3, _ = wcs_slit.get_transform("detector", "v2v3")(
                    truth_x, truth_y, with_bounding_box=bounding_box)
                v2v3_specs = (
                    ("V2 difference", truth_v2, pv2,
                     r"Relative V2 Difference = $\Delta$V2",
                     r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{truth}$)/V2$_{truth}$",
                     "Unable to create plot of relative V2 difference.",
                     "_rel_V2_diffs.png"),
                    ("V3 difference", truth_v3, pv3,
                     r"Relative V3 Difference = $\Delta$V3",
                     r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{truth}$)/V3$_{truth}$",
                     "Unable to create plot of relative V3 difference.",
                     "_rel_V3_diffs.png"),
                )
                for tested_quantity, truth_vals, pipe_vals, title, xlabel, unable_msg, suffix in v2v3_specs:
                    reldiff = auxfunc.get_reldiffarr_and_stats(
                        threshold_diff, truth_slity, truth_vals, pipe_vals, tested_quantity)
                    # converting to degrees to compare with truth, pipeline is in arcsec
                    if reldiff[-2][0] > 0.0:
                        print("\nConverting pipeline results to degrees to compare with truth file")
                        reldiff = auxfunc.get_reldiffarr_and_stats(
                            threshold_diff, truth_slity, truth_vals, pipe_vals / 3600., tested_quantity)
                    _fs_evaluate(reldiff, tested_quantity, pipeslit, threshold_diff, total_test_result, log_msgs)
                    plot_specs.append((reldiff, title, xlabel, unable_msg, suffix))

            # PLOTS
            if show_figs or save_figs:
                main_title = filt + "   " + grat + "   SLIT=" + pipeslit + "\n"
                name_args = (infile_name, basenameinfile_name, output_directory, pipeslit, det)
                for reldiff, title, xlabel, unable_msg, suffix in plot_specs:
                    _fs_plot_reldiff(reldiff, main_title + title + "\n", xlabel, unable_msg, suffix,
                                     name_args, show_figs, save_figs, log_msgs)
            else:
                msg = "NO plots were made because show_figs and save_figs were both set to False. \n"
                print(msg)
                log_msgs.append(msg)

        return _fs_final_result(total_test_result, log_msgs)

    finally:
        # release the models this function opened (never a model passed in by the caller)
        if truth_model is not None:
            truth_model.close()
        if opened_img and img is not None:
            img.close()


def _fs_stat_is_nan(value):
    """Return True if a summary statistic is NaN.

    ``value is np.nan`` only matches the ``np.nan`` object itself, so NaNs produced by numpy
    reductions (e.g. ``np.float64('nan')``) would slip through; test the value instead.
    Values numpy cannot test (e.g. None) count as not NaN.
    """
    try:
        return bool(np.isnan(value))
    except TypeError:
        return False


def _fs_evaluate(reldiff_data, tested_quantity, pipeslit, threshold_diff, total_test_result, log_msgs):
    """Log the statistics of one comparison, run the median test and record its outcome.

    reldiff_data is the 4-tuple returned by auxfunc.get_reldiffarr_and_stats:
    (difference image, not-NaN differences, statistics, printed statistics lines).
    The test uses statistics[1]. The outcome is stored as
    total_test_result[pipeslit][tested_quantity], so every quantity of a slit is kept.
    """
    _, _, stats, print_stats = reldiff_data
    log_msgs.extend(print_stats)
    test_result = auxfunc.does_median_pass_tes(stats[1], threshold_diff)
    msg = "\n * Result of the test for " + tested_quantity + ":  " + test_result + "\n"
    print(msg)
    log_msgs.append(msg)
    total_test_result.setdefault(pipeslit, OrderedDict())[tested_quantity] = test_result


def _fs_plot_name(infile_name, basenameinfile_name, output_directory, pipeslit, det, specific_plt_name):
    """Return the file name for a diagnostic plot.

    If infile_name is a path, the plot goes next to it (its base name is replaced). Otherwise
    it goes to output_directory, or to the current working directory when that is None.
    """
    fig_name = pipeslit + "_" + det + specific_plt_name
    if isinstance(infile_name, str):
        return infile_name.replace(basenameinfile_name, fig_name)
    if output_directory is not None:
        return os.path.join(output_directory, fig_name)
    plt_name = os.path.join(os.getcwd(), fig_name)
    print("No output_directory was provided. Figures will be saved in current working directory:")
    print(plt_name + "\n")
    return plt_name


def _fs_plot_reldiff(reldiff_data, title, xlabel, unable_msg, specific_plt_name, name_args,
                     show_figs, save_figs, log_msgs, bins=15, plt_origin=None):
    """Plot one relative-difference image and its histogram, or log why it cannot be plotted.

    reldiff_data is the 4-tuple from auxfunc.get_reldiffarr_and_stats. Nothing is plotted when
    its statistics[1] is NaN. name_args are the leading arguments of _fs_plot_name.
    """
    diff_img, notnan_diffs, stats, _ = reldiff_data
    if _fs_stat_is_nan(stats[1]):
        print(unable_msg)
        log_msgs.append(unable_msg)
        return
    info_img = [title, "x (pixels)", "y (pixels)"]
    info_hist = [xlabel, "N", bins, stats]
    plt_name = _fs_plot_name(*name_args, specific_plt_name)
    auxfunc.plt_two_2Dimgandhist(diff_img,
                                 notnan_diffs,
                                 info_img,
                                 info_hist,
                                 plt_name=plt_name,
                                 plt_origin=plt_origin,
                                 show_figs=show_figs,
                                 save_figs=save_figs)


def _fs_read_esa_extensions(esafile, esahdulist, det_num, log_msgs):
    """Read the ESA reference arrays for detector number det_num ("1" or "2").

    Returns (wave, slit_y, msa_x, msa_y, v2, v3, lambda_wcs). v2 and v3 are None, and a
    message is logged, when the V2V3X/V2V3Y extensions are missing.
    Raises KeyError when any of the DATA/LAMBDA/SLITY/MSAX/MSAY extensions is missing.
    """
    # DATA is not compared; reading it only verifies the detector's extensions exist
    fits.getdata(esafile, "DATA" + det_num)
    truth_wave = fits.getdata(esafile, "LAMBDA" + det_num)
    truth_slity = fits.getdata(esafile, "SLITY" + det_num)
    truth_msax = fits.getdata(esafile, "MSAX" + det_num)
    truth_msay = fits.getdata(esafile, "MSAY" + det_num)
    pyw = wcs.WCS(esahdulist['LAMBDA' + det_num].header)
    try:
        truth_v2 = fits.getdata(esafile, "V2V3X" + det_num)
        truth_v3 = fits.getdata(esafile, "V2V3Y" + det_num)
    except KeyError:
        msg = "Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions."
        print(msg)
        log_msgs.append(msg)
        truth_v2 = truth_v3 = None
    return truth_wave, truth_slity, truth_msax, truth_msay, truth_v2, truth_v3, pyw


def _fs_final_result(total_test_result, log_msgs):
    """Log every recorded test outcome and combine them into the overall verdict.

    The verdict is "PASSED" only if at least one test was recorded and none of them is
    "FAILED"; otherwise it is "FAILED". Returns (verdict, log_msgs).
    """
    tests_run = 0
    any_failed = False
    for sl, testdir in total_test_result.items():
        for t, tr in testdir.items():
            tests_run += 1
            if tr == "FAILED":
                any_failed = True
                msg = "\n * The test of " + t + " for slit " + sl + " FAILED."
            else:
                msg = "\n * The test of " + t + " for slit " + sl + " PASSED."
            print(msg)
            log_msgs.append(msg)

    final_result = "PASSED" if tests_run and not any_failed else "FAILED"
    if final_result == "PASSED":
        msg = "\n *** Final result for assign_wcs test will be reported as PASSED *** \n"
    else:
        msg = "\n *** Final result for assign_wcs test will be reported as FAILED *** \n"
    print(msg)
    log_msgs.append(msg)

    return final_result, log_msgs
Exemple #31
0
def _corrections_for_ifu(data, pathloss, source_type):
    """Calculate the pathloss correction arrays for NIRSpec IFU data.

    One-dimensional point-source and uniform-source pathloss vectors are
    computed from the first aperture of the reference data. They are then
    interpolated onto a 2-D wavelength map built from the WCS of every
    slice in ``NIRSPEC_IFU_SLICES``.

    Parameters
    ----------
    data : jwst.datamodels.DataModel
        The IFU data being operated on.

    pathloss : jwst.datamodels.DataModel
        The pathloss reference data; only ``apertures[0]`` is used.

    source_type : str or None
        Force processing using the specified source type. If None (or
        otherwise falsy), ``data.meta.target.source_type`` is used instead.

    Returns
    -------
    correction : same type as ``data``
        ``data`` holds the correction selected for the source type.
        ``pathloss_point`` and ``pathloss_uniform`` hold the point-source
        and uniform-source corrections. ``wavelength`` holds the 2-D
        wavelength map, which is NaN where no slice provides a valid
        wavelength.
    """

    # IFU targets are always inside the slit, so get_center is called
    # without a source position.
    xcenter, ycenter = get_center(data.meta.exposure.type, None)

    # Calculate the 1-d wavelength and pathloss vectors for the source position
    aperture = pathloss.apertures[0]
    wavelength_pointsource, pathloss_pointsource_vector, _ = calculate_pathloss_vector(
        aperture.pointsource_data, aperture.pointsource_wcs, xcenter, ycenter)
    wavelength_uniformsource, pathloss_uniform_vector, _ = calculate_pathloss_vector(
        aperture.uniform_data, aperture.uniform_wcs, xcenter, ycenter)

    # Wavelengths in the reference file are in meters;
    # need them to be in microns
    wavelength_pointsource *= 1.0e6
    wavelength_uniformsource *= 1.0e6

    # Build the 2-d wavelength map. Pixels not covered by any slice stay NaN.
    wavelength_array = np.full(data.shape, np.nan, dtype=np.float32)
    for slice_id in NIRSPEC_IFU_SLICES:
        slice_wcs = nirspec.nrs_wcs_set_input(data, slice_id)
        x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
        # Only the wavelength output of the slice WCS is needed here
        _, _, wavelength = slice_wcs(x, y)
        valid = ~np.isnan(wavelength)
        wavelength_array[y[valid].astype(int),
                         x[valid].astype(int)] = wavelength[valid]

    # Compute the point source pathloss 2D correction
    pathloss_2d_ps = interpolate_onto_grid(wavelength_array,
                                           wavelength_pointsource,
                                           pathloss_pointsource_vector)

    # Compute the uniform source pathloss 2D correction
    pathloss_2d_un = interpolate_onto_grid(wavelength_array,
                                           wavelength_uniformsource,
                                           pathloss_uniform_vector)

    # Use the appropriate correction for the source type; an explicit
    # source_type argument overrides the value stored in the data model.
    if is_pointsource(source_type or data.meta.target.source_type):
        pathloss_2d = pathloss_2d_ps
    else:
        pathloss_2d = pathloss_2d_un

    # Save the corrections. The `data` portion is the correction used.
    # The individual ones will be saved in the respective attributes.
    correction = type(data)(data=pathloss_2d)
    correction.pathloss_point = pathloss_2d_ps
    correction.pathloss_uniform = pathloss_2d_un
    correction.wavelength = wavelength_array

    return correction
Exemple #32
0
def test_functional_fs_msa(mode):
    """
    Compare Nirspec instrument model with IDT model for FS and MSA.

    Five points along the slit (slit x = 0, slit y from -0.5 to 0.5) are
    pushed through successive stages of the NIRSpec WCS, each at one of five
    wavelengths given in meters. At each stage the pipeline values are
    compared with the matching columns of the ESA functional model table.

    Parameters
    ----------
    mode : str
        ``'fs'`` to test fixed slit S200A1, or ``'msa'`` to test a single
        MSA shutter. Any other value raises ``ValueError``.
    """
    if mode == 'fs':
        model_file = 'fixed_slits_functional_ESA_v4_20180618.txt'
        hdul = create_nirspec_fs_file(grating='G395H', filter='F290LP')
        im = datamodels.ImageModel(hdul)
        refs = create_reference_files(im)
        pipeline = nirspec.create_pipeline(im,
                                           refs,
                                           slit_y_range=[-0.55, 0.55])
        w = wcs.WCS(pipeline)
        im.meta.wcs = w
        # Use slit S200A1
        slit_wcs = nirspec.nrs_wcs_set_input(im, 'S200A1')
    elif mode == 'msa':
        model_file = 'msa_functional_ESA_v2_20180620.txt'
        hdul = create_nirspec_mos_file(grating='G395H', filt='F290LP')
        im = datamodels.ImageModel(hdul)
        refs = create_reference_files(im)
        slit = trmodels.Slit(name=1,
                             shutter_id=4699,
                             xcen=319,
                             ycen=13,
                             ymin=-0.55000000000000004,
                             ymax=0.55000000000000004,
                             quadrant=3,
                             source_id=1,
                             shutter_state='x',
                             source_name='lamp',
                             source_alias='foo',
                             stellarity=100.0,
                             source_xpos=-0.5,
                             source_ypos=0.5)
        open_slits = [slit]
        pipeline = nirspec.slitlets_wcs(im, refs, open_slits)
        w = wcs.WCS(pipeline)
        im.meta.wcs = w
        slit_wcs = nirspec.nrs_wcs_set_input(im, 1)
    else:
        # Fail clearly instead of with an UnboundLocalError on model_file below
        raise ValueError(
            "Unsupported mode {!r}; expected 'fs' or 'msa'.".format(mode))

    ins_file = get_file_path(model_file)
    ins_tab = table.Table.read(ins_file, format='ascii')

    # Setup the test: slit coordinates and wavelengths (in meters)
    slitx = [0] * 5
    slity = [-.5, -.25, 0, .25, .5]
    lam = np.array([2.9, 3.39, 3.88, 4.37, 5]) * 10**-6

    # Slit to MSA absolute
    slit2msa = slit_wcs.get_transform('slit_frame', 'msa_frame')
    msax, msay, _ = slit2msa(slitx, slity, lam)

    assert_allclose(slitx, ins_tab['xslitpos'])
    assert_allclose(slity, ins_tab['yslitpos'])
    assert_allclose(msax, ins_tab['xmsapos'])
    # NOTE(review): 'ymaspos' is the column name used here; confirm it
    # matches the ESA table header (vs. 'ymsapos') before renaming.
    assert_allclose(msay, ins_tab['ymaspos'])

    # Coordinates at Collimator exit:
    # applies the inverse of the collimator reference model to the MSA
    # absolute coordinates.
    with datamodels.open(refs['collimator']) as col:
        colx, coly = col.model.inverse(msax, msay)
    assert_allclose(colx, ins_tab['xcoll'])
    assert_allclose(coly, ins_tab['ycoll'])

    # After applying directional cosines
    dircos = trmodels.Unitless2DirCos()
    xcolDircosi, ycolDircosi, z = dircos(colx, coly)
    assert_allclose(xcolDircosi, ins_tab['xcolDirCosi'])
    assert_allclose(ycolDircosi, ins_tab['ycolDirCosi'])

    # MSA to GWA entrance
    # This runs the Collimator forward, Unitless to Directional cosine, and
    # 3D Rotation. It uses the corrected GWA tilt value
    with datamodels.DisperserModel(refs['disperser']) as disp:
        disperser = nirspec.correct_tilt(disp, im.meta.instrument.gwa_xtilt,
                                         im.meta.instrument.gwa_ytilt)
    collimator2gwa = nirspec.collimator_to_gwa(refs, disperser)
    x_gwa_in, y_gwa_in, z_gwa_in = collimator2gwa(msax, msay)
    assert_allclose(x_gwa_in, ins_tab['xdispIn'])
    assert_allclose(y_gwa_in, ins_tab['ydispIn'])

    # Slit to GWA out
    slit2gwa = slit_wcs.get_transform('slit_frame', 'gwa')
    x_gwa_out, y_gwa_out, z_gwa_out = slit2gwa(slitx, slity, lam)
    assert_allclose(x_gwa_out, ins_tab['xdispLaw'])
    assert_allclose(y_gwa_out, ins_tab['ydispLaw'])

    # CAMERA entrance (assuming direction is from sky to detector)
    angles = [
        disperser['theta_x'], disperser['theta_y'], disperser['theta_z'],
        disperser['tilt_y']
    ]
    rotation = trmodels.Rotation3DToGWA(angles,
                                        axes_order="xyzy",
                                        name='rotation')
    dircos2unitless = trmodels.DirCos2Unitless()
    gwa2cam = rotation.inverse | dircos2unitless
    x_camera_entrance, y_camera_entrance = gwa2cam(x_gwa_out, y_gwa_out,
                                                   z_gwa_out)
    assert_allclose(x_camera_entrance, ins_tab['xcamCosi'])
    assert_allclose(y_camera_entrance, ins_tab['ycamCosi'])

    # at FPA
    with datamodels.CameraModel(refs['camera']) as camera:
        x_fpa, y_fpa = camera.model.inverse(x_camera_entrance,
                                            y_camera_entrance)
    assert_allclose(x_fpa, ins_tab['xfpapos'])
    assert_allclose(y_fpa, ins_tab['yfpapos'])

    # at SCA: pipeline pixels are 0-based, the IDT results are 1-based,
    # hence the "+ 1" in the comparisons below.
    slit2sca = slit_wcs.get_transform('slit_frame', 'sca')
    x_sca_nrs1, y_sca_nrs1 = slit2sca(slitx, slity, lam)
    # At NRS2
    with datamodels.FPAModel(refs['fpa']) as fpa:
        x_sca_nrs2, y_sca_nrs2 = fpa.nrs2_model.inverse(x_fpa, y_fpa)
    # The first two wavelengths are compared on NRS1, the next two on NRS2
    wvlns_on_nrs1 = slice(2)
    wvlns_on_nrs2 = slice(2, 4)
    assert_allclose(x_sca_nrs1[wvlns_on_nrs1] + 1, ins_tab['i'][wvlns_on_nrs1])
    assert_allclose(y_sca_nrs1[wvlns_on_nrs1] + 1, ins_tab['j'][wvlns_on_nrs1])
    assert_allclose(x_sca_nrs2[wvlns_on_nrs2] + 1, ins_tab['i'][wvlns_on_nrs2])
    assert_allclose(y_sca_nrs2[wvlns_on_nrs2] + 1, ins_tab['j'][wvlns_on_nrs2])

    # at oteip
    slit2oteip = slit_wcs.get_transform('slit_frame', 'oteip')
    x_oteip, y_oteip, _ = slit2oteip(slitx, slity, lam)
    assert_allclose(x_oteip, ins_tab['xOTEIP'])
    assert_allclose(y_oteip, ins_tab['yOTEIP'])

    # at v2, v3: the pipeline returns arcsec; divide by 3600 to compare
    # with the table values in degrees
    slit2v23 = slit_wcs.get_transform('slit_frame', 'v2v3')
    v2, v3, _ = slit2v23(slitx, slity, lam)
    v2 /= 3600
    v3 /= 3600
    assert_allclose(v2, ins_tab['xV2V3'])
    assert_allclose(v3, ins_tab['yV2V3'])
def compare_wcs(infile_name,
                esa_files_path=None,
                show_figs=True,
                save_figs=False,
                threshold_diff=1.0e-7,
                raw_data_root_file=None,
                output_directory=None,
                debug=False):
    """
    This function does the WCS comparison from the world coordinates calculated using the pipeline
    data model with the ESA intermediary files.

    Args:
        infile_name: str, name of the output fits file from the assign_wcs step (with full path)
        esa_files_path: str, full path of where to find all ESA intermediary products to make comparisons for the tests
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots or not
        threshold_diff: float, threshold difference between pipeline output and ESA file
        raw_data_root_file: None or string, name of the raw file that produced the _uncal.fits file for caldetector1
        output_directory: None or string, path to the output_directory where to save the plots
        debug: boolean, if true a series of print statements will show on-screen

    Returns:
        - plots, if told to save and/or show them.
        - median_diff: Boolean, True if smaller or equal to threshold
        - log_msgs: list, all print statements are captured in this variable

    """

    log_msgs = []

    # get grating and filter info from the rate file header
    if isinstance(infile_name, str):
        # get the datamodel from the assign_wcs output file
        img = datamodels.ImageModel(infile_name)
        msg = 'infile_name=' + infile_name
        print(msg)
        log_msgs.append(msg)
        basenameinfile_name = os.path.basename(infile_name)
    else:
        img = infile_name
        basenameinfile_name = ""

    det = img.meta.instrument.detector  # fits.getval(infile_name, "DETECTOR", 0)
    lamp = img.meta.instrument.lamp_state  # fits.getval(infile_name, "LAMP", 0)
    grat = img.meta.instrument.grating  # fits.getval(infile_name, "GRATING", 0)
    filt = img.meta.instrument.filter  # fits.getval(infile_name, "FILTER", 0)
    msg = "from assign_wcs file  -->     Detector: " + det + "   Grating: " + grat + "   Filter: " + \
          filt + "   Lamp: " + lamp
    print(msg)
    log_msgs.append(msg)

    # loop over the slices: 0 - 29
    img = datamodels.ImageModel(infile_name)
    slice_list = img.meta.wcs.get_transform('gwa', 'slit_frame').slits

    # dictionary to record if each test passed or not
    total_test_result = OrderedDict()

    # loop over the slices
    for indiv_slice in slice_list:
        if int(indiv_slice) < 10:
            pslice = "0" + repr(indiv_slice)
        else:
            pslice = repr(indiv_slice)
        msg = "\n Working with slice: " + pslice
        print(msg)
        log_msgs.append(msg)

        # Get the ESA trace
        if raw_data_root_file is None:
            _, raw_data_root_file = auxfunc.get_modeused_and_rawdatrt_PTT_cfg_file(
                infile_name)
        specifics = [pslice]
        esafile = auxfunc.get_esafile(esa_files_path, raw_data_root_file,
                                      "IFU", specifics)[0]

        # skip the test if the esafile was not found
        if "ESA file not found" in esafile:
            msg1 = " * compare_wcs_ifu.py is exiting because the corresponding ESA file was not found."
            msg2 = "   -> The WCS test is now set to skip and no plots will be generated. "
            print(msg1)
            print(msg2)
            log_msgs.append(msg1)
            log_msgs.append(msg2)
            FINAL_TEST_RESULT = "skip"
            return FINAL_TEST_RESULT, log_msgs

        # Open the trace in the esafile
        msg = "Using this ESA file: \n" + str(esafile)
        print(msg)
        log_msgs.append(msg)
        with fits.open(esafile) as esahdulist:
            print("* ESA file contents ")
            esahdulist.info()
            esa_slice_id = esahdulist[0].header['SLICEID']
            # first check is esa_slice == to pipe_slice?
            if indiv_slice == esa_slice_id:
                msg = "\n -> Same slice found for pipeline and ESA data: " + repr(
                    indiv_slice) + "\n"
                print(msg)
                log_msgs.append(msg)
            else:
                msg = "\n -> Missmatch of slices for pipeline and ESA data: " + repr(
                    indiv_slice) + esa_slice_id + "\n"
                print(msg)
                log_msgs.append(msg)

            # Assign variables according to detector
            skipv2v3test = True
            if det == "NRS1":
                esa_flux = fits.getdata(esafile, "DATA1")
                esa_wave = fits.getdata(esafile, "LAMBDA1")
                esa_slity = fits.getdata(esafile, "SLITY1")
                esa_msax = fits.getdata(esafile, "MSAX1")
                esa_msay = fits.getdata(esafile, "MSAY1")
                pyw = wcs.WCS(esahdulist['LAMBDA1'].header)
                try:
                    esa_v2v3x = fits.getdata(esafile, "V2V3X1")
                    esa_v2v3y = fits.getdata(esafile, "V2V3Y1")
                    skipv2v3test = False
                except KeyError:
                    msg = "Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions."
                    print(msg)
                    log_msgs.append(msg)
            if det == "NRS2":
                try:
                    esa_flux = fits.getdata(esafile, "DATA2")
                    esa_wave = fits.getdata(esafile, "LAMBDA2")
                    esa_slity = fits.getdata(esafile, "SLITY2")
                    esa_msax = fits.getdata(esafile, "MSAX2")
                    esa_msay = fits.getdata(esafile, "MSAY2")
                    pyw = wcs.WCS(esahdulist['LAMBDA2'].header)
                    msg = "using NRS2 extensions"
                    print(msg)
                    log_msgs.append(msg)
                    try:
                        esa_v2v3x = fits.getdata(esafile, "V2V3X2")
                        esa_v2v3y = fits.getdata(esafile, "V2V3Y2")
                        skipv2v3test = False
                    except KeyError:
                        msg = "Skipping tests for V2 and V3 because ESA file does not contain corresponding extensions."
                        print(msg)
                        log_msgs.append(msg)
                except KeyError:
                    msg1 = "\n * compare_wcs_ifu.py is exiting because there are no extensions that match detector NRS2 in the ESA file."
                    msg2 = "   -> The WCS test is now set to skip and no plots will be generated. \n"
                    print(msg1)
                    print(msg2)
                    log_msgs.append(msg1)
                    log_msgs.append(msg2)
                    FINAL_TEST_RESULT = "skip"
                    return FINAL_TEST_RESULT, log_msgs

        # get the WCS object for this particular slit
        wcs_slice = nirspec.nrs_wcs_set_input(img, indiv_slice)

        # if we want to print all available transforms, uncomment line below
        #print(wcs_slice)

        # The WCS object attribute bounding_box shows all valid inputs, i.e. the actual area of the data according
        # to the slice. Inputs outside of the bounding_box return NaN values.
        #bbox = wcs_slice.bounding_box
        #print('wcs_slice.bounding_box: ', wcs_slice.bounding_box)

        # In different observing modes the WCS may have different coordinate frames. To see available frames
        # uncomment line below.
        #print("Avalable frames: ", wcs_slice.available_frames)

        if debug:
            # To get specific pixel values use following syntax:
            det2slit = wcs_slice.get_transform('detector', 'slit_frame')
            slitx, slity, lam = det2slit(700, 1080)
            print("slitx: ", slitx)
            print("slity: ", slity)
            print("lambda: ", lam)

            # The number of inputs and outputs in each frame can vary. This can be checked with:
            print('Number on inputs: ', det2slit.n_inputs)
            print('Number on outputs: ', det2slit.n_outputs)

        # Create x, y indices using the Trace WCS
        pipey, pipex = np.mgrid[:esa_wave.shape[0], :esa_wave.shape[1]]
        esax, esay = pyw.all_pix2world(pipex, pipey, 0)
        # need to account for different detector orientation for NRS2
        if det == "NRS2":
            esax = 2049 - esax
            esay = 2049 - esay
        #print( "x,y: "+repr(esax-1)+repr(esay-1) )

        # Compute pipeline RA, DEC, and lambda
        pra, pdec, pwave = wcs_slice(esax - 1, esay - 1)
        # => RETURNS: RA, DEC, LAMBDA (lam *= 10**-6 to convert to microns)
        pwave *= 10**-6
        #print( "wavelengths: "+repr(pwave) )

        # calculate and print statistics for slit-y and x relative differences
        tested_quantity = "Wavelength Difference"
        #print(" ESA wavelength: ", esa_wave)
        #print(" Pipeline wavelength: ", pwave)
        rel_diff_pwave_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, esa_slity, esa_wave, pwave, tested_quantity)
        rel_diff_pwave_img, notnan_rel_diff_pwave, notnan_rel_diff_pwave_stats, stats_print_statements = rel_diff_pwave_data
        for msg in stats_print_statements:
            print(msg)
            log_msgs.append(msg)
        result = auxfunc.does_median_pass_tes(notnan_rel_diff_pwave_stats[1],
                                              threshold_diff)
        total_test_result["slice" + pslice] = {tested_quantity: result}

        # get the transforms for pipeline slit-y
        det2slit = wcs_slice.get_transform('detector', 'slit_frame')
        slitx, slity, _ = det2slit(esax - 1, esay - 1)
        tested_quantity = "Slit-Y Difference"
        # calculate and print statistics for slit-y and x relative differences
        rel_diff_pslity_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, esa_slity, esa_slity, slity, tested_quantity)
        # calculate and print statistics for slit-y and x absolute differences
        #rel_diff_pslity_data = auxfunc.get_reldiffarr_and_stats(threshold_diff, esa_slity, esa_slity, slity, tested_quantity, abs=True)
        rel_diff_pslity_img, notnan_rel_diff_pslity, notnan_rel_diff_pslity_stats, stats_print_statements = rel_diff_pslity_data
        for msg in stats_print_statements:
            print(msg)
            log_msgs.append(msg)
        result = auxfunc.does_median_pass_tes(notnan_rel_diff_pslity_stats[1],
                                              threshold_diff)
        total_test_result["slice" + pslice] = {tested_quantity: result}

        # do the same for MSA x, y and V2, V3
        detector2msa = wcs_slice.get_transform("detector", "msa_frame")
        pmsax, pmsay, _ = detector2msa(
            esax - 1, esay - 1
        )  # => RETURNS: msaX, msaY, LAMBDA (lam *= 10**-6 to convert to microns)
        # MSA-x
        tested_quantity = "MSA_X Difference"
        reldiffpmsax_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, esa_slity, esa_msax, pmsax, tested_quantity)
        reldiffpmsax_img, notnan_reldiffpmsax, notnan_reldiffpmsax_stats, stats_print_statements = reldiffpmsax_data
        for msg in stats_print_statements:
            print(msg)
            log_msgs.append(msg)
        result = auxfunc.does_median_pass_tes(notnan_reldiffpmsax_stats[1],
                                              threshold_diff)
        total_test_result["slice" + pslice] = {tested_quantity: result}
        # MSA-y
        tested_quantity = "MSA_Y Difference"
        reldiffpmsay_data = auxfunc.get_reldiffarr_and_stats(
            threshold_diff, esa_slity, esa_msay, pmsay, tested_quantity)
        reldiffpmsay_img, notnan_reldiffpmsay, notnan_reldiffpmsay_stats, stats_print_statements = reldiffpmsay_data
        for msg in stats_print_statements:
            print(msg)
            log_msgs.append(msg)
        result = auxfunc.does_median_pass_tes(notnan_reldiffpmsay_stats[1],
                                              threshold_diff)
        total_test_result["slice" + pslice] = {tested_quantity: result}

        # V2 and V3
        if not skipv2v3test:
            detector2v2v3 = wcs_slice.get_transform("detector", "v2v3")
            pv2, pv3, _ = detector2v2v3(esax - 1, esay - 1)
            # => RETURNS: v2, v3, LAMBDA (lam *= 10**-6 to convert to microns)
            tested_quantity = "V2 difference"
            # converting to degrees to compare with ESA
            reldiffpv2_data = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, esa_slity, esa_v2v3x, pv2, tested_quantity)
            if reldiffpv2_data[-2][0] > 0.0:
                print(
                    "\nConverting pipeline results to degrees to compare with ESA"
                )
                pv2 = pv2 / 3600.
                reldiffpv2_data = auxfunc.get_reldiffarr_and_stats(
                    threshold_diff, esa_slity, esa_v2v3x, pv2, tested_quantity)
            reldiffpv2_img, notnan_reldiffpv2, notnan_reldiffpv2_stats, stats_print_statements = reldiffpv2_data
            for msg in stats_print_statements:
                print(msg)
                log_msgs.append(msg)
            result = auxfunc.does_median_pass_tes(notnan_reldiffpv2_stats[1],
                                                  threshold_diff)
            total_test_result["slice" + pslice] = {tested_quantity: result}
            tested_quantity = "V3 difference"
            # converting to degrees to compare with ESA
            reldiffpv3_data = auxfunc.get_reldiffarr_and_stats(
                threshold_diff, esa_slity, esa_v2v3y, pv3, tested_quantity)
            if reldiffpv3_data[-2][0] > 0.0:
                print(
                    "\nConverting pipeline results to degrees to compare with ESA"
                )
                pv3 = pv3 / 3600.
                reldiffpv3_data = auxfunc.get_reldiffarr_and_stats(
                    threshold_diff, esa_slity, esa_v2v3y, pv3, tested_quantity)
            reldiffpv3_img, notnan_reldiffpv3, notnan_reldiffpv3_stats, stats_print_statements = reldiffpv3_data
            for msg in stats_print_statements:
                print(msg)
                log_msgs.append(msg)
            result = auxfunc.does_median_pass_tes(notnan_reldiffpv3_stats[1],
                                                  threshold_diff)
            total_test_result["slice" + pslice] = {tested_quantity: result}

        # PLOTS
        if show_figs or save_figs:
            # set the common variables
            main_title = filt + "   " + grat + "   SLICE=" + pslice + "\n"
            bins = 15  # binning for the histograms, if None the function will automatically calculate them
            #             lolim_x, uplim_x, lolim_y, uplim_y
            plt_origin = None

            # Wavelength
            title = main_title + r"Relative wavelength difference = $\Delta \lambda$" + "\n"
            info_img = [title, "x (pixels)", "y (pixels)"]
            xlabel, ylabel = r"Relative $\Delta \lambda$ = ($\lambda_{pipe} - \lambda_{ESA}) / \lambda_{ESA}$", "N"
            info_hist = [xlabel, ylabel, bins, notnan_rel_diff_pwave_stats]
            if notnan_rel_diff_pwave_stats[1] is np.nan:
                msg = "Unable to create plot of relative wavelength difference."
                print(msg)
                log_msgs.append(msg)
            else:
                specific_plt_name = "_rel_wave_diffs.png"
                if isinstance(infile_name, str):
                    plt_name = infile_name.replace(
                        basenameinfile_name,
                        pslice + "_" + det + specific_plt_name)
                else:
                    if output_directory is not None:
                        plt_name = os.path.join(
                            output_directory,
                            pslice + "_" + det + specific_plt_name)
                    else:
                        plt_name = os.path.join(
                            os.getcwd(),
                            pslice + "_" + det + specific_plt_name)
                        print(
                            "No output_directory was provided. Figures will be saved in current working directory:"
                        )
                        print(plt_name + "\n")
                auxfunc.plt_two_2Dimgandhist(rel_diff_pwave_img,
                                             notnan_rel_diff_pwave,
                                             info_img,
                                             info_hist,
                                             plt_name=plt_name,
                                             plt_origin=plt_origin,
                                             show_figs=show_figs,
                                             save_figs=save_figs)

            # Slit-y
            title = main_title + r"Relative slit position = $\Delta$slit_y" + "\n"
            info_img = [title, "x (pixels)", "y (pixels)"]
            xlabel, ylabel = r"Relative $\Delta$slit_y = (slit_y$_{pipe}$ - slit_y$_{ESA}$)/slit_y$_{ESA}$", "N"
            info_hist = [xlabel, ylabel, bins, notnan_rel_diff_pslity_stats]
            if notnan_rel_diff_pslity_stats[1] is np.nan:
                msg = "Unable to create plot of relative slit-y difference."
                print(msg)
                log_msgs.append(msg)
            else:
                specific_plt_name = "_rel_slitY_diffs.png"
                if isinstance(infile_name, str):
                    plt_name = infile_name.replace(
                        basenameinfile_name,
                        pslice + "_" + det + specific_plt_name)
                else:
                    if output_directory is not None:
                        plt_name = os.path.join(
                            output_directory,
                            pslice + "_" + det + specific_plt_name)
                    else:
                        plt_name = None
                        save_figs = False
                        print(
                            "No output_directory was provided. Figures will NOT be saved."
                        )
                auxfunc.plt_two_2Dimgandhist(rel_diff_pslity_img,
                                             notnan_rel_diff_pslity,
                                             info_img,
                                             info_hist,
                                             plt_name=plt_name,
                                             plt_origin=plt_origin,
                                             show_figs=show_figs,
                                             save_figs=save_figs)

            # MSA-x
            title = main_title + r"Relative MSA-x Difference = $\Delta$MSA_x" + "\n"
            info_img = [title, "x (pixels)", "y (pixels)"]
            xlabel, ylabel = r"Relative $\Delta$MSA_x = (MSA_x$_{pipe}$ - MSA_x$_{ESA}$)/MSA_x$_{ESA}$", "N"
            info_hist = [xlabel, ylabel, bins, notnan_reldiffpmsax_stats]
            if notnan_reldiffpmsax_stats[1] is np.nan:
                msg = "Unable to create plot of relative MSA-x difference."
                print(msg)
                log_msgs.append(msg)
            else:
                specific_plt_name = "_rel_MSAx_diffs.png"
                if isinstance(infile_name, str):
                    plt_name = infile_name.replace(
                        basenameinfile_name,
                        pslice + "_" + det + specific_plt_name)
                else:
                    if output_directory is not None:
                        plt_name = os.path.join(
                            output_directory,
                            pslice + "_" + det + specific_plt_name)
                    else:
                        plt_name = None
                        save_figs = False
                        print(
                            "No output_directory was provided. Figures will NOT be saved."
                        )
                auxfunc.plt_two_2Dimgandhist(reldiffpmsax_img,
                                             notnan_reldiffpmsax,
                                             info_img,
                                             info_hist,
                                             plt_name=plt_name,
                                             plt_origin=plt_origin,
                                             show_figs=show_figs,
                                             save_figs=save_figs)

            # MSA-y
            title = main_title + r"Relative MSA-y Difference = $\Delta$MSA_y" + "\n"
            info_img = [title, "x (pixels)", "y (pixels)"]
            xlabel, ylabel = r"Relative $\Delta$MSA_y = (MSA_y$_{pipe}$ - MSA_y$_{ESA}$)/MSA_y$_{ESA}$", "N"
            info_hist = [xlabel, ylabel, bins, notnan_reldiffpmsay_stats]
            if notnan_reldiffpmsay_stats[1] is np.nan:
                msg = "Unable to create plot of relative MSA-y difference."
                print(msg)
                log_msgs.append(msg)
            else:
                specific_plt_name = "_rel_MSAy_diffs.png"
                if isinstance(infile_name, str):
                    plt_name = infile_name.replace(
                        basenameinfile_name,
                        pslice + "_" + det + specific_plt_name)
                else:
                    if output_directory is not None:
                        plt_name = os.path.join(
                            output_directory,
                            pslice + "_" + det + specific_plt_name)
                    else:
                        plt_name = None
                        save_figs = False
                        print(
                            "No output_directory was provided. Figures will NOT be saved."
                        )
                auxfunc.plt_two_2Dimgandhist(reldiffpmsay_img,
                                             notnan_reldiffpmsay,
                                             info_img,
                                             info_hist,
                                             plt_name=plt_name,
                                             plt_origin=plt_origin,
                                             show_figs=show_figs,
                                             save_figs=save_figs)

            if not skipv2v3test:
                # V2
                title = main_title + r"Relative V2 Difference = $\Delta$V2" + "\n"
                info_img = [title, "x (pixels)", "y (pixels)"]
                xlabel, ylabel = r"Relative $\Delta$V2 = (V2$_{pipe}$ - V2$_{ESA}$)/V2$_{ESA}$", "N"
                hist_data = notnan_reldiffpv2
                info_hist = [xlabel, ylabel, bins, notnan_reldiffpv2_stats]
                if notnan_reldiffpv2_stats[1] is np.nan:
                    msg = "Unable to create plot of relative V2 difference."
                    print(msg)
                    log_msgs.append(msg)
                else:
                    specific_plt_name = "_rel_V2_diffs.png"
                    if isinstance(infile_name, str):
                        plt_name = infile_name.replace(
                            basenameinfile_name,
                            pslice + "_" + det + specific_plt_name)
                    else:
                        if output_directory is not None:
                            plt_name = os.path.join(
                                output_directory,
                                pslice + "_" + det + specific_plt_name)
                        else:
                            plt_name = None
                            save_figs = False
                            print(
                                "No output_directory was provided. Figures will NOT be saved."
                            )
                    auxfunc.plt_two_2Dimgandhist(reldiffpv2_img,
                                                 hist_data,
                                                 info_img,
                                                 info_hist,
                                                 plt_name=plt_name,
                                                 plt_origin=plt_origin,
                                                 show_figs=show_figs,
                                                 save_figs=save_figs)

                # V3
                title = main_title + r"Relative V3 Difference = $\Delta$V3" + "\n"
                info_img = [title, "x (pixels)", "y (pixels)"]
                xlabel, ylabel = r"Relative $\Delta$V3 = (V3$_{pipe}$ - V3$_{ESA}$)/V3$_{ESA}$", "N"
                hist_data = notnan_reldiffpv3
                info_hist = [xlabel, ylabel, bins, notnan_reldiffpv3_stats]
                if notnan_reldiffpv3_stats[1] is np.nan:
                    msg = "Unable to create plot of relative V3 difference."
                    print(msg)
                    log_msgs.append(msg)
                else:
                    specific_plt_name = "_rel_V3_diffs.png"
                    if isinstance(infile_name, str):
                        plt_name = infile_name.replace(
                            basenameinfile_name,
                            pslice + "_" + det + specific_plt_name)
                    else:
                        if output_directory is not None:
                            plt_name = os.path.join(
                                output_directory,
                                pslice + "_" + det + specific_plt_name)
                        else:
                            plt_name = None
                            save_figs = False
                            print(
                                "No output_directory was provided. Figures will NOT be saved."
                            )
                    auxfunc.plt_two_2Dimgandhist(reldiffpv3_img,
                                                 hist_data,
                                                 info_img,
                                                 info_hist,
                                                 plt_name=plt_name,
                                                 plt_origin=plt_origin,
                                                 show_figs=show_figs,
                                                 save_figs=save_figs)

        else:
            msg = "NO plots were made because show_figs and save_figs were both set to False. \n"
            print(msg)
            log_msgs.append(msg)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    FINAL_TEST_RESULT = "FAILED"
    for sl, testdir in total_test_result.items():
        for t, tr in testdir.items():
            if tr == "FAILED":
                FINAL_TEST_RESULT = "FAILED"
                msg = "\n * The test of " + t + " for slice " + sl + " FAILED."
                print(msg)
                log_msgs.append(msg)
            else:
                FINAL_TEST_RESULT = "PASSED"
                msg = "\n * The test of " + t + " for slice " + sl + " PASSED."
                print(msg)
                log_msgs.append(msg)

    if FINAL_TEST_RESULT == "PASSED":
        msg = "\n *** Final result for assign_wcs test will be reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
    else:
        msg = "\n *** Final result for assign_wcs test will be reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)

    return FINAL_TEST_RESULT, log_msgs
Exemple #34
0
def test_functional_ifu_prism():
    """Compare Nirspec instrument model with IDT model for IFU prism.

    Builds a NIRSpec IFU PRISM/CLEAR exposure and its WCS pipeline. For
    slice 0, it checks the coordinates produced at each intermediate frame
    (MSA, slicer, collimator, GWA, camera, FPA, SCA, OTEIP, V2/V3) against
    the matching columns of the ESA model table ``model_file``.
    """
    # setup test
    model_file = 'ifu_prism_functional_ESA_v1_20180619.txt'
    hdu1 = create_nirspec_ifu_file(grating='PRISM',
                                   filter='CLEAR',
                                   gwa_xtil=0.35986012,
                                   gwa_ytil=0.13448857,
                                   gwa_tilt=37.1)
    im = datamodels.ImageModel(hdu1)
    refs = create_reference_files(im)
    pipeline = nirspec.create_pipeline(im, refs, slit_y_range=[-0.55, 0.55])
    w = wcs.WCS(pipeline)
    im.meta.wcs = w
    slit_wcs = nirspec.nrs_wcs_set_input(im, 0)  # use slice 0
    ins_file = get_file_path(model_file)
    ins_tab = table.Table.read(ins_file, format='ascii')
    # Five test points along the slit (x=0, y from -0.5 to 0.5)
    slitx = [0] * 5
    slity = [-.5, -.25, 0, .25, .5]
    # NOTE(review): .7e-7 is 0.07 um, while the other entries are >= 1 um;
    # confirm it matches the wavelengths used to generate the ESA table.
    lam = np.array([.7e-7, 1e-6, 2e-6, 3e-6, 5e-6])
    order, wrange = nirspec.get_spectral_order_wrange(im,
                                                      refs['wavelengthrange'])
    im.meta.wcsinfo.sporder = order
    im.meta.wcsinfo.waverange_start = wrange[0]
    im.meta.wcsinfo.waverange_end = wrange[1]

    # Slit to MSA entrance
    # This includes the Slicer transform and the IFUFORE transform
    slit2msa = slit_wcs.get_transform('slit_frame', 'msa_frame')
    msax, msay, _ = slit2msa(slitx, slity, lam)
    assert_allclose(slitx, ins_tab['xslitpos'])
    assert_allclose(slity, ins_tab['yslitpos'])
    assert_allclose(msax + 0.0073, ins_tab['xmsapos'],
                    rtol=1e-2)  # expected offset
    assert_allclose(msay + 0.0085, ins_tab['ymaspos'],
                    rtol=1e-2)  # expected offset

    # Slicer
    slit2slicer = slit_wcs.get_transform('slit_frame', 'slicer')
    x_slicer, y_slicer, _ = slit2slicer(slitx, slity, lam)

    # MSA exit
    # Applies the IFUPOST transform to coordinates at the Slicer
    with datamodels.IFUPostModel(refs['ifupost']) as ifupost:
        ifupost_transform = nirspec._create_ifupost_transform(ifupost.slice_0)
    x_msa_exit, y_msa_exit = ifupost_transform(x_slicer, y_slicer, lam)
    assert_allclose(x_msa_exit, ins_tab['xmsapos'])
    assert_allclose(y_msa_exit, ins_tab['ymaspos'])

    # Coordinates at Collimator exit
    # Applies the inverse of the collimator reference model to the MSA-exit
    # coordinates (presumably the reference model is stored in the
    # collimator -> MSA direction; TODO confirm)
    with datamodels.open(refs['collimator']) as col:
        colx, coly = col.model.inverse(x_msa_exit, y_msa_exit)
    assert_allclose(colx, ins_tab['xcoll'])
    assert_allclose(coly, ins_tab['ycoll'])

    # After applying directional cosines
    dircos = trmodels.Unitless2DirCos()
    xcolDircosi, ycolDircosi, z = dircos(colx, coly)
    assert_allclose(xcolDircosi, ins_tab['xcolDirCosi'])
    assert_allclose(ycolDircosi, ins_tab['ycolDirCosi'])

    # Slit to GWA entrance
    # applies the Collimator forward, Unitless to Directional and 3D Rotation to MSA exit coordinates
    with datamodels.DisperserModel(refs['disperser']) as disp:
        disperser = nirspec.correct_tilt(disp, im.meta.instrument.gwa_xtilt,
                                         im.meta.instrument.gwa_ytilt)
    collimator2gwa = nirspec.collimator_to_gwa(refs, disperser)
    x_gwa_in, y_gwa_in, z_gwa_in = collimator2gwa(x_msa_exit, y_msa_exit)
    assert_allclose(x_gwa_in, ins_tab['xdispIn'])

    # Slit to GWA out
    # Runs slit--> slicer --> msa_exit --> collimator --> dircos --> rotation --> angle_from_grating equation
    slit2gwa = slit_wcs.get_transform('slit_frame', 'gwa')
    x_gwa_out, y_gwa_out, z_gwa_out = slit2gwa(slitx, slity, lam)
    assert_allclose(x_gwa_out, ins_tab['xdispLaw'])
    assert_allclose(y_gwa_out, ins_tab['ydispLaw'])

    # CAMERA entrance (assuming direction is from sky to detector)
    # Rebuild the GWA rotation from the tilt-corrected disperser angles and
    # invert it, then convert directional cosines back to unitless coords.
    angles = [
        disperser['theta_x'], disperser['theta_y'], disperser['theta_z'],
        disperser['tilt_y']
    ]
    rotation = trmodels.Rotation3DToGWA(angles,
                                        axes_order="xyzy",
                                        name='rotation')
    dircos2unitless = trmodels.DirCos2Unitless()
    gwa2cam = rotation.inverse | dircos2unitless
    x_camera_entrance, y_camera_entrance = gwa2cam(x_gwa_out, y_gwa_out,
                                                   z_gwa_out)
    assert_allclose(x_camera_entrance, ins_tab['xcamCosi'])
    assert_allclose(y_camera_entrance, ins_tab['ycamCosi'])

    # at FPA
    with datamodels.CameraModel(refs['camera']) as camera:
        x_fpa, y_fpa = camera.model.inverse(x_camera_entrance,
                                            y_camera_entrance)
    assert_allclose(x_fpa, ins_tab['xfpapos'])
    assert_allclose(y_fpa, ins_tab['yfpapos'])

    # at SCA
    slit2sca = slit_wcs.get_transform('slit_frame', 'sca')
    x_sca_nrs1, y_sca_nrs1 = slit2sca(slitx, slity, lam)

    # At NRS2
    # NOTE(review): x_sca_nrs2 / y_sca_nrs2 are computed but never asserted;
    # only the NRS1 SCA coordinates are compared below.
    with datamodels.FPAModel(refs['fpa']) as fpa:
        x_sca_nrs2, y_sca_nrs2 = fpa.nrs2_model.inverse(x_fpa, y_fpa)
    # The +1 presumably converts 0-based pipeline pixels to the table's
    # 1-based 'i'/'j' pixel indices -- TODO confirm
    assert_allclose(x_sca_nrs1 + 1, ins_tab['i'])
    assert_allclose(y_sca_nrs1 + 1, ins_tab['j'])

    # at oteip
    # Goes through slicer, ifufore, and fore transforms
    slit2oteip = slit_wcs.get_transform('slit_frame', 'oteip')
    x_oteip, y_oteip, _ = slit2oteip(slitx, slity, lam)
    assert_allclose(x_oteip, ins_tab['xOTEIP'])
    assert_allclose(y_oteip, ins_tab['yOTEIP'])

    # at v2, v3 [in arcsec]
    # The transform output is divided by 3600 (arcsec -> degrees, assuming
    # the transform returns arcsec as stated above) before comparison.
    slit2v23 = slit_wcs.get_transform('slit_frame', 'v2v3')
    v2, v3, _ = slit2v23(slitx, slity, lam)
    v2 /= 3600
    v3 /= 3600
    assert_allclose(v2, ins_tab['xV2V3'])
    assert_allclose(v3, ins_tab['yV2V3'])
Exemple #35
0
def _divide_by_pathloss(model, pathloss_2d):
    """Divide a model's flux, error and variance arrays by a pathloss correction.

    Parameters
    ----------
    model : data model or slit object
        Object carrying ``data``, ``err``, ``var_poisson``, ``var_rnoise``
        and ``var_flat`` arrays; these are modified in place.

    pathloss_2d : ndarray
        Correction to divide by; must broadcast against ``model.data``.
    """
    model.data /= pathloss_2d
    model.err /= pathloss_2d
    # Variances scale with the square of the flux correction
    pathloss_sq = pathloss_2d**2
    model.var_poisson /= pathloss_sq
    model.var_rnoise /= pathloss_sq
    # var_flat is optional; only correct it when it is populated
    if model.var_flat is not None and np.size(model.var_flat) > 0:
        model.var_flat /= pathloss_sq


def _correct_slit_pathloss(slit, aperture, xcenter, ycenter):
    """Apply the pathloss correction to one MOS or fixed-slit spectrum.

    This evaluates the point-source and uniform-source pathloss vectors at
    the source position and interpolates both onto ``slit.wavelength``. It
    then divides out the one matching ``slit.source_type``. Both 2-D
    corrections are attached to the slit as ``pathloss_point`` and
    ``pathloss_uniform``.

    Returns
    -------
    bool
        True if the source is inside the slit and the correction was
        applied; False if the slit was left unchanged.
    """
    (wavelength_pointsource, pathloss_pointsource_vector,
     is_inside_slit) = calculate_pathloss_vector(
         aperture.pointsource_data, aperture.pointsource_wcs,
         xcenter, ycenter)
    (wavelength_uniformsource, pathloss_uniform_vector,
     _) = calculate_pathloss_vector(
         aperture.uniform_data, aperture.uniform_wcs, xcenter, ycenter)
    if not is_inside_slit:
        return False

    # Wavelengths in the reference file are in meters,
    # need them to be in microns
    wavelength_pointsource *= 1.0e6
    wavelength_uniformsource *= 1.0e6

    wavelength_array = slit.wavelength

    # Compute the point source and uniform source pathloss 2D corrections
    pathloss_2d_ps = interpolate_onto_grid(
        wavelength_array, wavelength_pointsource, pathloss_pointsource_vector)
    pathloss_2d_un = interpolate_onto_grid(
        wavelength_array, wavelength_uniformsource, pathloss_uniform_vector)

    # Use the appropriate correction for this slit
    if is_pointsource(slit.source_type):
        pathloss_2d = pathloss_2d_ps
    else:
        pathloss_2d = pathloss_2d_un

    # Apply the pathloss 2D correction and attach to datamodel
    _divide_by_pathloss(slit, pathloss_2d)
    slit.pathloss_point = pathloss_2d_ps
    slit.pathloss_uniform = pathloss_2d_un
    return True


def _ifu_wavelength_array(input_model):
    """Build a 2-D wavelength image for a NIRSpec IFU exposure.

    The image has the shape of ``input_model``. It is filled slice by slice
    from each IFU slice's WCS. Pixels not covered by any slice, or whose
    wavelength evaluates to NaN, remain NaN.
    """
    wavelength_array = np.full(input_model.shape, np.nan, dtype=np.float32)
    for slice_num in NIRSPEC_IFU_SLICES:
        slice_wcs = nirspec.nrs_wcs_set_input(input_model, slice_num)
        x, y = wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
        _, _, wavelength = slice_wcs(x, y)
        valid = ~np.isnan(wavelength)
        wavelength_array[y[valid].astype(int),
                         x[valid].astype(int)] = wavelength[valid]
    return wavelength_array


def _soss_correction_vector(aperture, pupil_wheel_position, data_ncols):
    """Compute the per-column NIRISS SOSS pathloss correction.

    The reference array ``aperture.pointsource_data[0]`` is indexed as
    ``[reference_row, pupil_wheel_column]``. The pupil-wheel axis is mapped
    through CRPIX1/CRVAL1/CDELT1 and linearly interpolated. Each 1-indexed
    data column is mapped to the nearest reference row through
    CRPIX2/CRVAL2/CDELT2.

    Returns
    -------
    ndarray of float32, shape (data_ncols,)
        Correction per data column; 1.0 where the reference file has no
        coverage.
    """
    pathloss_array = aperture.pointsource_data[0]
    nrows, ncols = pathloss_array.shape
    correction = np.ones(data_ncols, dtype=np.float32)
    ref_wcs = aperture.pointsource_wcs

    # Fractional 0-indexed position of the pupil wheel on the reference grid
    pupil_wheel_index = ref_wcs.crpix1 + (pupil_wheel_position -
                                          ref_wcs.crval1) / ref_wcs.cdelt1 - 1

    # Both ix and ix + 1 must be valid columns for the linear interpolation
    if pupil_wheel_index < 0 or pupil_wheel_index > (ncols - 2):
        log.warning("Pupil Wheel position outside reference file coverage")
        log.warning("Setting pathloss correction to 1.0")
        return correction

    ix = int(pupil_wheel_index)
    dx = pupil_wheel_index - ix
    for column in range(data_ncols):
        column_1indexed = column + 1
        # Nearest reference row for this data column
        refrow_index = math.floor(ref_wcs.crpix2 +
                                  (column_1indexed - ref_wcs.crval2) /
                                  ref_wcs.cdelt2 - 0.5)
        # Columns outside the reference coverage keep a correction of 1.0
        if 0 <= refrow_index <= nrows - 1:
            correction[column] = (1.0 - dx) * pathloss_array[refrow_index, ix] + \
                dx * pathloss_array[refrow_index, ix + 1]
    return correction


def do_correction(input_model, pathloss_model):
    """
    Short Summary
    -------------
    Execute all tasks for Path Loss Correction

    Supported exposure types are NRS_MSASPEC, NRS_FIXEDSLIT, NRS_BRIGHTOBJ,
    NRS_IFU and NIS_SOSS. For any other type, an unmodified copy of the
    input is returned and ``meta.cal_step.pathloss`` is not set.

    Parameters
    ----------
    input_model : data model object
        science data to be corrected

    pathloss_model : pathloss model object
        pathloss correction data

    Returns
    -------
    output_model : data model object
        Corrected science data with pathloss extensions added. Its
        ``meta.cal_step.pathloss`` is 'COMPLETE' when processed. For
        NIS_SOSS it is 'SKIPPED' when the correction cannot be applied.

    """
    exp_type = input_model.meta.exposure.type
    log.info(f'Input exposure type is {exp_type}')
    output_model = input_model.copy()

    # NIRSpec MOS data
    if exp_type == 'NRS_MSASPEC':
        # Loop over all MOS slitlets (numbered from 1 for logging)
        for slit_number, slit in enumerate(output_model.slits, start=1):
            log.info(f'Working on slit {slit_number}')
            size = slit.data.size

            # Only work on slits with data.size > 0
            if size <= 0:
                log.warning(f"Slit has data size = {size}")
                log.warning("Skipping pathloss correction for this slitlet")
                continue

            # Get centering
            xcenter, ycenter = get_center(exp_type, slit)
            # Get the aperture from the reference file that matches the
            # number of open shutters in this slitlet
            nshutters = util.get_num_msa_open_shutters(slit.shutter_state)
            aperture = get_aperture_from_model(pathloss_model, nshutters)
            if aperture is None:
                log.warning(
                    "Cannot find matching pathloss model for slit with "
                    f"{nshutters} shutters")
                log.warning("Skipping pathloss correction for this slit")
            elif not _correct_slit_pathloss(slit, aperture, xcenter, ycenter):
                log.warning(
                    "Source is outside slit. Skipping "
                    f"pathloss correction for slit {slit_number}")

        # Set step status to complete
        output_model.meta.cal_step.pathloss = 'COMPLETE'

    # NIRSpec fixed-slit data
    elif exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']:
        # Loop over all slits contained in the input
        for slit in output_model.slits:
            log.info(f'Working on slit {slit.name}')

            # Get centering
            xcenter, ycenter = get_center(exp_type, slit)
            # Get the aperture from the reference file that matches the slit
            aperture = get_aperture_from_model(pathloss_model, slit.name)
            if aperture is None:
                log.warning(
                    f'Cannot find matching pathloss model for {slit.name}')
                log.warning('Skipping pathloss correction for this slit')
                continue

            log.info(f'Using aperture {aperture.name}')
            if not _correct_slit_pathloss(slit, aperture, xcenter, ycenter):
                log.warning('Source is outside slit. Skipping '
                            f'pathloss correction for slit {slit.name}')

        # Set step status to complete
        output_model.meta.cal_step.pathloss = 'COMPLETE'

    # NIRSpec IFU
    elif exp_type == 'NRS_IFU':
        # IFU targets are always inside slit, so the "inside" flag returned
        # by calculate_pathloss_vector is ignored
        xcenter, ycenter = get_center(exp_type, None)
        # Calculate the 1-d wavelength and pathloss vectors for the source position
        aperture = pathloss_model.apertures[0]
        (wavelength_pointsource, pathloss_pointsource_vector,
         _) = calculate_pathloss_vector(aperture.pointsource_data,
                                        aperture.pointsource_wcs, xcenter,
                                        ycenter)
        (wavelength_uniformsource, pathloss_uniform_vector,
         _) = calculate_pathloss_vector(aperture.uniform_data,
                                        aperture.uniform_wcs, xcenter,
                                        ycenter)
        # Wavelengths in the reference file are in meters;
        # need them to be in microns
        wavelength_pointsource *= 1.0e6
        wavelength_uniformsource *= 1.0e6

        wavelength_array = _ifu_wavelength_array(input_model)

        # Compute the point source and uniform source pathloss 2D corrections
        pathloss_2d_ps = interpolate_onto_grid(wavelength_array,
                                               wavelength_pointsource,
                                               pathloss_pointsource_vector)
        pathloss_2d_un = interpolate_onto_grid(wavelength_array,
                                               wavelength_uniformsource,
                                               pathloss_uniform_vector)

        # Use the appropriate correction for the source type
        if is_pointsource(input_model.meta.target.source_type):
            pathloss_2d = pathloss_2d_ps
        else:
            pathloss_2d = pathloss_2d_un

        # Apply the pathloss 2D correction and attach to datamodel
        _divide_by_pathloss(output_model, pathloss_2d)
        output_model.pathloss_point = pathloss_2d_ps
        output_model.pathloss_uniform = pathloss_2d_un

        # This might be useful to other steps
        output_model.wavelength = wavelength_array

        # Set the step status to complete
        output_model.meta.cal_step.pathloss = 'COMPLETE'

    # NIRISS SOSS
    elif exp_type == 'NIS_SOSS':
        # The NIRISS SOSS pathloss correction mainly corrects for 2nd and
        # 3rd order flux that falls outside the subarray aperture. It depends
        # on the pupil wheel position and the column number (or wavelength).
        # Correcting by column number means only a 1-d interpolation in
        # pupil wheel position is needed.

        # Omit correction if this is a TSO observation
        if input_model.meta.visit.tsovisit:
            log.warning("NIRISS SOSS TSO observations skip the pathloss step")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model

        # Get the pupil wheel position
        pupil_wheel_position = input_model.meta.instrument.pupil_position
        if pupil_wheel_position is None:
            log.warning(
                'Unable to get pupil wheel position from PWCPOS keyword '
                f'for {input_model.meta.filename}')
            log.warning("Pathloss correction skipped")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model

        # Get the aperture from the reference file that matches the subarray
        subarray = input_model.meta.subarray.name
        aperture = get_aperture_from_model(pathloss_model, subarray)
        if aperture is None:
            log.warning('Unable to get Aperture from reference file '
                        f'for subarray {subarray}')
            log.warning("Pathloss correction skipped")
            output_model.meta.cal_step.pathloss = 'SKIPPED'
            return output_model

        log.info(f'Aperture {aperture.name} selected from reference file')

        # Build the per-column correction, then broadcast it over all rows
        _, data_ncols = input_model.data.shape
        correction = _soss_correction_vector(aperture, pupil_wheel_position,
                                             data_ncols)
        pathloss_2d = np.broadcast_to(correction, input_model.data.shape)
        _divide_by_pathloss(output_model, pathloss_2d)
        output_model.pathloss_point = pathloss_2d

        # Set step status to complete
        output_model.meta.cal_step.pathloss = 'COMPLETE'

    return output_model