예제 #1
0
def slit_data_a():
    """Build a one-slit MultiSlitModel of synthetic "science" data for tests.

    The slit holds a constant data array (every pixel equal to 10), an
    all-zero DQ array, and a wavelength array in which each row is the same
    linear ramp shifted by a fixed step relative to the row above it.

    Returns
    -------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Model containing the single synthetic slit.
    """
    n_rows, n_cols = shape = (5, 9)

    science = np.full(shape, 10., dtype=np.float32)
    flags = np.zeros(shape, dtype=np.uint32)

    # Wavelength ramp shared by every row (endpoint included).
    base_row = np.linspace(1.3, 4.8, num=n_cols, dtype=np.float32)

    # Each successive row is offset from the base ramp by one more step.
    row_step = 0.1
    wavelength = np.stack(
        [base_row + row * row_step for row in range(n_rows)]
    )
    wavelength = np.around(wavelength, 4)

    # Wrap the arrays in a slit and attach it to a fresh multi-slit model.
    input_model = datamodels.MultiSlitModel()
    input_model.slits.append(
        datamodels.SlitModel(init=None,
                             data=science,
                             dq=flags,
                             wavelength=wavelength)
    )

    return input_model
예제 #2
0
def test_has_uniform_source():
    """Check ``bar.has_uniform_source`` for unset, POINT and UNKNOWN sources."""
    slit = datamodels.SlitModel(data=np.zeros((10, 100), dtype=np.float32))

    # No source_type has been assigned: the slit is expected to be
    # treated as an extended (uniform) source.
    assert bar.has_uniform_source(slit)

    # An explicit point source must not be reported as uniform.
    slit.source_type = 'POINT'
    assert not bar.has_uniform_source(slit)

    # Any value other than 'POINT' is expected to count as extended.
    slit.source_type = 'UNKNOWN'
    assert bar.has_uniform_source(slit)
예제 #3
0
def test_fs_correction():
    """Verify ``correct_nrs_fs_bkg`` against a hand-computed correction.

    A slit of ones is given distinct point/uniform flat-field, pathloss and
    photom arrays; the expected result is the data scaled by the
    uniform/point ratios of flat-field and pathloss and by the point/uniform
    ratio of photom.
    """
    ones = np.ones((5, 5))

    # Constant-valued correction arrays, one point-source and one
    # uniform-source version for each calibration step.
    flat_point = 1.5 * ones
    flat_uniform = ones / 1.2
    loss_point = 2 * ones
    loss_uniform = ones / 2.1
    photom_point = 1.1 * ones
    photom_uniform = 1.23 * ones

    slit = datamodels.SlitModel(
        data=ones,
        flatfield_point=flat_point, flatfield_uniform=flat_uniform,
        pathloss_point=loss_point, pathloss_uniform=loss_uniform,
        photom_point=photom_point, photom_uniform=photom_uniform,
    )

    # Compute the expectation before running the correction.
    expected = (slit.data
                * (flat_uniform / flat_point)
                * (loss_uniform / loss_point)
                * (photom_point / photom_uniform))
    result = correct_nrs_fs_bkg(slit, primary_slit=True)

    assert np.allclose(expected, result.data, rtol=1.e-7)
예제 #4
0
def _corrections_for_fixedslit(slit, pathloss, exp_type, source_type):
    """Calculate the pathloss correction arrays for a fixed-slit slit.

    Parameters
    ----------
    slit : jwst.datamodels.SlitModel
        The slit being operated on.

    pathloss : jwst.datamodels.DataModel
        The pathloss reference data.

    exp_type : str
        Exposure type.

    source_type : str or None
        Force processing using the specified source type. When falsy, the
        slit's own ``source_type`` is used instead.

    Returns
    -------
    correction : jwst.datamodels.SlitModel or None
        The correction arrays, or None when no aperture in the reference
        data matches the slit or the source is reported as outside the slit.
    """
    # Get centering
    xcenter, ycenter = get_center(exp_type, slit)

    # Pick the reference-file aperture that matches this slit by name.
    aperture = get_aperture_from_model(pathloss, slit.name)
    if aperture is None:
        log.warning(f'Cannot find matching pathloss model for {slit.name}')
        log.warning('Skipping pathloss correction for this slit')
        return None

    log.info(f'Using aperture {aperture.name}')

    # 1-D wavelength and pathloss vectors at the source position, for the
    # point-source and uniform-source reference data respectively.
    wave_ps, loss_ps, inside_slit = calculate_pathloss_vector(
        aperture.pointsource_data, aperture.pointsource_wcs,
        xcenter, ycenter)
    wave_un, loss_un, _ = calculate_pathloss_vector(
        aperture.uniform_data, aperture.uniform_wcs,
        xcenter, ycenter)

    if not inside_slit:
        log.warning('Source is outside slit. Skipping '
                    f'pathloss correction for slit {slit.name}')
        return None

    # Wavelengths in the reference file are in meters,
    # need them to be in microns
    wave_ps *= 1.0e6
    wave_un *= 1.0e6

    slit_wavelengths = slit.wavelength

    # Resample both 1-D pathloss vectors onto the slit's wavelength grid.
    pathloss_2d_ps = interpolate_onto_grid(slit_wavelengths, wave_ps, loss_ps)
    pathloss_2d_un = interpolate_onto_grid(slit_wavelengths, wave_un, loss_un)

    # A forced source type takes precedence over the slit's own value.
    use_point = is_pointsource(source_type or slit.source_type)
    pathloss_2d = pathloss_2d_ps if use_point else pathloss_2d_un

    # `data` carries the correction actually selected; both variants are
    # also kept in their dedicated attributes.
    correction = datamodels.SlitModel(data=pathloss_2d)
    correction.pathloss_point = pathloss_2d_ps
    correction.pathloss_uniform = pathloss_2d_un

    return correction
예제 #5
0
def _corrections_for_mos(slit, pathloss, exp_type, source_type=None):
    """Calculate the pathloss correction arrays for a MOS slit.

    Parameters
    ----------
    slit : jwst.datamodels.SlitModel
        The slit being operated on.

    pathloss : jwst.datamodels.DataModel
        The pathloss reference data.

    exp_type : str
        Exposure type.

    source_type : str or None
        Force processing using the specified source type. When falsy, the
        slit's own ``source_type`` is used instead.

    Returns
    -------
    correction : jwst.datamodels.SlitModel or None
        The correction arrays, or None when the slit has no data, no
        aperture in the reference data matches the slit's number of open
        shutters, or the source is reported as outside the slitlet.
    """
    size = slit.data.size

    # Only work on slits with data.size > 0
    if size <= 0:
        log.warning(f"Slit has data size = {size}")
        return None

    # Get centering
    xcenter, ycenter = get_center(exp_type, slit)

    # The reference aperture is selected by the number of open shutters.
    nshutters = util.get_num_msa_open_shutters(slit.shutter_state)
    aperture = get_aperture_from_model(pathloss, nshutters)
    if aperture is None:
        # The trailing space keeps the two literals from running together
        # ("slit with3 shutters") in the emitted message.
        log.warning("Cannot find matching pathloss model for slit with "
                    f"{nshutters} shutters")
        return None

    # 1-D wavelength and pathloss vectors at the source position, for the
    # point-source and uniform-source reference data respectively.
    (wavelength_pointsource, pathloss_pointsource_vector,
     is_inside_slitlet) = calculate_pathloss_vector(
         aperture.pointsource_data, aperture.pointsource_wcs,
         xcenter, ycenter)
    (wavelength_uniformsource, pathloss_uniform_vector,
     _) = calculate_pathloss_vector(
         aperture.uniform_data, aperture.uniform_wcs,
         xcenter, ycenter)

    if not is_inside_slitlet:
        log.warning("Source is outside slit.")
        return None

    # Wavelengths in the reference file are in meters,
    # need them to be in microns
    wavelength_pointsource *= 1.0e6
    wavelength_uniformsource *= 1.0e6

    wavelength_array = slit.wavelength

    # Compute the point source pathloss 2D correction
    pathloss_2d_ps = interpolate_onto_grid(
        wavelength_array, wavelength_pointsource,
        pathloss_pointsource_vector)

    # Compute the uniform source pathloss 2D correction
    pathloss_2d_un = interpolate_onto_grid(
        wavelength_array, wavelength_uniformsource,
        pathloss_uniform_vector)

    # Use the appropriate correction for this slit; a forced source type
    # takes precedence over the slit's own value.
    if is_pointsource(source_type or slit.source_type):
        pathloss_2d = pathloss_2d_ps
    else:
        pathloss_2d = pathloss_2d_un

    # Save the corrections. The `data` portion is the correction used.
    # The individual ones are saved in the respective attributes.
    correction = datamodels.SlitModel(data=pathloss_2d)
    correction.pathloss_point = pathloss_2d_ps
    correction.pathloss_uniform = pathloss_2d_un

    return correction
예제 #6
0
def _calc_correction(slitlet, barshadow_model, source_type):
    """Compute the bar shadow correction array for one slitlet.

    Parameters
    ----------
    slitlet : jwst.datamodels.SlitModel
        The slitlet to calculate for.

    barshadow_model : `~jwst.datamodels.BarshadowModel`
        Bar shadow data model from the reference file.

    source_type : str or None
        Force processing using the specified source type.

    Returns
    -------
    correction : jwst.datamodels.SlitModel
        The correction to be applied. Its data is all ones when the source
        is not uniform or the slitlet has an empty shutter state.
    """
    slitlet_number = slitlet.slitlet_id

    # Pieces from the reference model used to assemble the shadow array
    shutter_elements = create_shutter_elements(barshadow_model)
    w0 = barshadow_model.crval1
    wave_increment = barshadow_model.cdelt1
    y_increment = barshadow_model.cdelt2
    shutter_height = 1.0 / y_increment

    # Start from a unit correction; it is only replaced for
    # extended/uniform sources with at least one shutter.
    correction = datamodels.SlitModel(data=np.ones(slitlet.data.shape))

    if not has_uniform_source(slitlet, source_type):
        log.info(
            "Bar shadow correction skipped for slitlet %d (source not uniform)"
            % slitlet_number)
        return correction

    shutter_status = slitlet.shutter_state
    n_shutters = len(shutter_status)
    if n_shutters == 0:
        log.info("Slitlet %d has zero length, correction skipped" %
                 slitlet_number)
        return correction

    shadow = create_shadow(shutter_elements, shutter_status)

    # Grid of pixel indices covering the slit's WCS bounding box
    x, y = wcstools.grid_from_bounding_box(
        slitlet.meta.wcs.bounding_box, step=(1, 1))

    # Transform from the 'detector' frame to the 'slit_frame'; only the
    # slit y coordinate and the wavelength are needed below.
    det2slit = slitlet.meta.wcs.get_transform('detector', 'slit_frame')
    _, yslit, wavelength = det2slit(x, y)

    # Index (possibly half-integer) of the middle of the slitlet within
    # shutter_status; the fiducial is always placed here, regardless of
    # where the source sits.
    index_of_fiducial = (n_shutters - 1) / 2.0

    # If the source shutter ('x') is off-center in a multi-shutter slitlet,
    # recenter yslit on its mean so the source appears centered, which is
    # the appropriate treatment for extended/uniform sources (the shadow
    # does not depend on source location).
    if n_shutters > 1:
        src_loc = shutter_status.find('x')
        if src_loc != -1 and float(src_loc) != index_of_fiducial:
            yslit -= np.nanmean(yslit)

    # The returned y values are scaled to where the slit height is 1
    # (i.e. a slit goes from -0.5 to 0.5).  The barshadow array is scaled
    # so that the separation between the slit centers is 1,
    # i.e. slit height + interslit bar
    yslit = yslit / SLITRATIO

    # The shutters go downwards, i.e. the first shutter in shutter_status
    # corresponds to the last in the shadow array.  So the center of the
    # first shutter has index shadow.shape[0] - shutter_height, and each
    # subsequent shutter center sits shutter_height further along.
    index_of_fiducial_in_array = shadow.shape[0] - shutter_height * (
        1 + index_of_fiducial)

    # Convert slit y and wavelength into (row, column) positions in the
    # bar shadow array.
    yrow = index_of_fiducial_in_array + yslit * shutter_height
    wcol = (wavelength - w0) / wave_increment

    # Interpolate the bar shadow correction at those positions
    correction.data = interpolate(yrow, wcol, shadow)

    return correction
예제 #7
0
def contam_corr(input_model, waverange, photom, max_cores):
    """
    Estimate and subtract WFSS contamination for every slit cutout.

    A full-frame grism image is simulated for all non-zero spectral orders.
    For each slit, the simulated spectrum of that slit's own source is
    removed from the full simulation, and the remainder (the contamination
    estimate) is cut out at the slit location and subtracted from the slit
    data.

    Parameters
    ----------
    input_model : `~jwst.datamodels.MultiSlitModel`
        Input data model containing 2D spectral cutouts
    waverange : `~jwst.datamodels.WavelengthrangeModel`
        Wavelength range reference file model
    photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel`
        Photom (flux cal) reference file model
    max_cores : string
        Number of cores to use for multiprocessing. 'none' means a single
        process. 'quarter', 'half' and 'all' select that fraction of the
        cores reported by ``multiprocessing.cpu_count()`` (never fewer than
        one). Any other value falls back to a single process.

    Returns
    -------
    output_model : `~jwst.datamodels.MultiSlitModel`
        A copy of the input_model that has been corrected
    simul_model : `~jwst.datamodels.ImageModel`
        Full-frame simulated image of the grism exposure
    contam_model : `~jwst.datamodels.MultiSlitModel`
        Contamination estimate images for each source slit

    """
    # Translate the max_cores setting into a process count
    if max_cores == 'none':
        ncpus = 1
    else:
        num_cores = multiprocessing.cpu_count()
        cores_by_setting = {
            'quarter': num_cores // 4 or 1,
            'half': num_cores // 2 or 1,
            'all': num_cores,
        }
        # Unrecognized settings fall back to a single process
        ncpus = cores_by_setting.get(max_cores, 1)
        log.debug(f"Found {num_cores} cores; using {ncpus}")

    # The corrected slits are written into a copy of the input
    output_model = input_model.copy()

    # Segmentation map for this grism exposure
    seg_model = datamodels.open(input_model.meta.segmentation_map)

    # Direct image from which the segmentation map was constructed
    image_names = [input_model.meta.direct_image]
    log.debug(f"Direct image names={image_names}")

    # Grism WCS, taken from the first slit of the input model
    grism_wcs = input_model.slits[0].meta.wcs

    # Spectral orders listed in the wavelengthrange reference file,
    # ignoring any order 0 entries
    spec_orders = np.asarray(waverange.order)
    spec_orders = spec_orders[spec_orders != 0]
    log.debug(f"Spectral orders defined = {spec_orders}")

    # NOTE: NIRCam WFSS has filters in the FILTER wheel and gratings in the
    # PUPIL wheel, while NIRISS WFSS is the opposite. For NIRISS the filter
    # name therefore comes from the PUPIL keyword value.
    instrument = input_model.meta.instrument
    filter_kwd = instrument.filter
    pupil_kwd = instrument.pupil
    filter_name = pupil_kwd if instrument.name == 'NIRISS' else filter_kwd

    # Per-order wavelength limits and sensitivity (inverse flux cal) curves
    wmin, wmax = {}, {}
    sens_waves, sens_response = {}, {}
    for order in spec_orders:
        limits = waverange.get_wfss_wavelength_range(filter_name, [order])[order]
        wmin[order] = limits[0]
        wmax[order] = limits[1]
        sens_waves[order], sens_response[order] = get_photom_data(
            photom, filter_kwd, pupil_kwd, order)
    log.debug(f"wmin={wmin}, wmax={wmax}")

    # Simulate the full grism image, one order at a time, and sum the orders
    simul_all = None
    obs = Observation(image_names, seg_model, grism_wcs, filter_name,
                      boundaries=[0, 2047, 0, 2047], max_cpu=ncpus)
    for order in spec_orders:

        log.info(f"Creating full simulated grism image for order {order}")
        obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order],
                         sens_response[order])

        # Fold this order into the running total
        if simul_all is None:
            simul_all = obs.simulated_image
        else:
            simul_all += obs.simulated_image

    # Package the full-frame simulation as an image model
    simul_model = datamodels.ImageModel(data=simul_all)
    simul_model.update(input_model, only="PRIMARY")

    # Build a contamination image for, and correct, each slit in turn
    log.info("Creating contamination image for each individual source")
    contam_model = datamodels.MultiSlitModel()
    contam_model.update(input_model)
    contam_slits = []
    for slit in output_model.slits:

        # Locate this source among the observation's IDs
        sid = slit.source_id
        order = slit.meta.wcsinfo.spectral_order
        chunk = np.where(obs.IDs == sid)[0][0]

        # Simulate only this source on a blank frame
        obs.simulated_image = np.zeros(obs.dims)
        obs.disperse_chunk(chunk, order, wmin[order], wmax[order],
                           sens_waves[order], sens_response[order])
        this_source = obs.simulated_image

        # Everything in the full simulation that is not this source
        contam = simul_all - this_source

        # Cut out the region matching the slit; xstart/ystart are
        # converted from 1-based to 0-based array indices.
        col_start = slit.xstart - 1
        row_start = slit.ystart - 1
        cols = slice(col_start, col_start + slit.xsize)
        rows = slice(row_start, row_start + slit.ysize)
        cutout = contam[rows, cols]

        contam_slit = datamodels.SlitModel(data=cutout)
        copy_slit_info(slit, contam_slit)
        contam_slits.append(contam_slit)

        # Remove the contamination estimate from the slit data in place
        slit.data -= cutout

    # Store the contamination estimates for all slits
    contam_model.slits.extend(contam_slits)

    # Set the step status to COMPLETE
    output_model.meta.cal_step.wfss_contam = 'COMPLETE'

    return output_model, simul_model, contam_model