Exemple #1
0
def abl_to_v2v3l(input_model, reference_files):
    """
    Build the transform from the (alpha, beta, lambda) frame to (V2, V3, lambda).

    Parameters
    ----------
    input_model : `jwst.datamodels.ImagingModel`
        Data model; ``meta.instrument.channel`` and ``meta.instrument.band``
        select which channel/band distortion models are used.
    reference_files : dict
        Dictionary {reftype: reference file name}; the "distortion" and
        "wavelengthrange" entries are read.

    Returns
    -------
    RegionsSelector
        Forward: a wavelength-range label mapper picks the channel number,
        and each channel number selects ``(ab -> v2v3, in degrees) & Identity(lam)``.
        Backward: the same label mapper (copied) with each channel's inverse,
        ``(degrees -> arcsec) | original channel inverse``, on the spatial part.
    """
    instrument = input_model.meta.instrument
    band = instrument.band
    # One label per channel digit, e.g. channel digit + band name; these labels
    # key both the distortion models and the wavelength ranges below.
    channel_bands = [chan + band for chan in instrument.channel]

    with DistortionMRSModel(reference_files['distortion']) as dist:
        chan_models = dict(zip(dist.abv2v3_model.channel_band,
                               dist.abv2v3_model.model))

    with WavelengthrangeModel(reference_files['wavelengthrange']) as wavemodel:
        wave_ranges = dict(zip(wavemodel.waverange_selector,
                               wavemodel.wavelengthrange))

    label_map = {}
    channel_transforms = {}
    # Each reference file covers two channels, so every label is processed.
    for label in channel_bands:
        chan_num = int(label[0])
        # Wavelength range -> channel number (uses only the third input, lam).
        label_map[tuple(wave_ranges[label])] = (
            models.Mapping((2,), name="mapping_lam")
            | models.Const1D(chan_num, name="channel #"))

        lam_passthrough = models.Identity(1, name='identity_lam')
        lam_passthrough._inputs = ('lam', )

        forward = chan_models[label]
        # Keep the reference-file inverse, then drop it from the forward model
        # so the compound built below gets the explicit inverse assigned here.
        backward = forward.inverse
        del forward.inverse

        # Reference distortion is in arcsec; convert to degrees.
        # Remove this degrees conversion once pipeline can handle v2,v3 in arcsec.
        arcsec_to_deg = models.Scale(1 / 3600) & models.Scale(1 / 3600)
        spatial = forward | arcsec_to_deg
        spatial.inverse = (models.Scale(3600) & models.Scale(3600)) | backward

        # Pass the wavelength through untouched as the third output.
        channel_transforms[chan_num] = spatial & lam_passthrough

    wave_range_mapper = selector.LabelMapperRange(
        ('alpha', 'beta', 'lam'),
        label_map,
        inputs_mapping=models.Mapping([2]))
    wave_range_mapper.inverse = wave_range_mapper.copy()

    return selector.RegionsSelector(('alpha', 'beta', 'lam'),
                                    ('v2', 'v3', 'lam'),
                                    label_mapper=wave_range_mapper,
                                    selector=channel_transforms)
Exemple #2
0
def lrs(input_model, reference_files):
    """
    The LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline.

    It has three coordinate frames: "detector", "v2v3" and "world".
    Uses the "specwcs" and "distortion" reference files.

    Parameters
    ----------
    input_model : `jwst.datamodels.ImagingModel`
        Data model; ``meta.exposure.type`` selects which zero-point header
        keywords are read from the "specwcs" file.
    reference_files : dict
        Dictionary {reftype: reference file name}; must contain the
        "distortion" and "specwcs" keys.

    Returns
    -------
    pipeline : list
        ``[(detector, det_to_v2v3), (v2v3, v23_to_world), (world, None)]``.

    Raises
    ------
    ValueError
        If the exposure type is neither ``mir_lrs-fixedslit`` nor
        ``mir_lrs-slitless`` (no zero point can be determined).
    """
    exp_type = input_model.meta.exposure.type.lower()
    if exp_type not in ('mir_lrs-fixedslit', 'mir_lrs-slitless'):
        # Previously this fell through and failed later with an unbound
        # ``zero_point``; fail early with a clear message instead.
        raise ValueError(
            "Unsupported exposure type for the LRS WCS: {!r}".format(exp_type))

    # Setup the frames.
    detector = cf.Frame2D(name='detector',
                          axes_order=(0, 1),
                          unit=(u.pix, u.pix))
    spec = cf.SpectralFrame(name='wavelength',
                            axes_order=(2, ),
                            unit=(u.micron, ),
                            axes_names=('lambda', ))
    sky = cf.CelestialFrame(reference_frame=coord.ICRS(), name='sky')
    v2v3_spatial = cf.Frame2D(name='v2v3_spatial',
                              axes_order=(0, 1),
                              unit=(u.arcsec, u.arcsec))
    v2v3 = cf.CompositeFrame(name="v2v3", frames=[v2v3_spatial, spec])
    world = cf.CompositeFrame(name="world", frames=[sky, spec])

    # Determine the distortion model.
    subarray2full = subarray_transform(input_model)
    with DistortionModel(reference_files['distortion']) as dist:
        distortion = dist.model

    # subarray_transform() may return None; only prepend it when present.
    if subarray2full is not None:
        distortion = subarray2full | distortion

    # Incorporate the small rotation
    angle = np.arctan(0.00421924)
    rotation = models.Rotation2D(angle)
    distortion = distortion | rotation

    # Load and process the reference data.
    with fits.open(reference_files['specwcs']) as ref:
        lrsdata = np.array([row for row in ref[1].data])
        header = ref[1].header

        # Get the zero point from the reference data.
        # The zero_point is X, Y  (which should be COLUMN, ROW)
        # TODO: Are imx, imy 0- or 1-indexed?  We are treating them here as
        # 0-indexed.  Since they are FITS, they are probably 1-indexed.
        if exp_type == 'mir_lrs-fixedslit':
            zero_point = header['imx'], header['imy']
        else:  # 'mir_lrs-slitless' (validated above)
            zero_point = header['imxsltl'], header['imysltl']

    # Create the bounding_box from columns 3, 4 and 5 of the reference table,
    # offset by the zero point and padded by half a pixel.
    x0 = lrsdata[:, 3]
    y0 = lrsdata[:, 4]
    x1 = lrsdata[:, 5]

    bb = ((x0.min() - 0.5 + zero_point[0], x1.max() + 0.5 + zero_point[0]),
          (y0.min() - 0.5 + zero_point[1], y0.max() + 0.5 + zero_point[1]))

    # Compute the v2v3 to sky.
    tel2sky = pointing.v23tosky(input_model)

    # To compute the spatial detector to V2V3 transform:
    # Take a row centered on zero_point_y and convert it to v2, v3.
    # The forward transform uses constant ``y`` values for each ``x``.
    # The inverse transform uses constant ``v3`` values for each ``v2``.
    zero_point_v2v3 = distortion(*zero_point)

    spatial_forward = models.Identity(1) & models.Const1D(
        zero_point[1]) | distortion
    spatial_forward.inverse = (
        models.Identity(1) & models.Const1D(zero_point_v2v3[1])
        | distortion.inverse)

    # Create the spectral transforms.
    lrs_wav_model = jwmodels.LRSWavelength(lrsdata, zero_point)

    # Optional barycentric correction; metadata may lack the attribute.
    try:
        velosys = input_model.meta.wcsinfo.velosys
    except AttributeError:
        pass
    else:
        if velosys is not None:
            velocity_corr = velocity_correction(velosys)
            lrs_wav_model = lrs_wav_model | velocity_corr
            log.info("Applied Barycentric velocity correction : {}".format(
                velocity_corr[1].amplitude.value))

    # Inputs (x, y) are duplicated: one copy feeds the spatial transform,
    # the other the wavelength model.
    det_to_v2v3 = models.Mapping(
        (0, 1, 0, 1)) | spatial_forward & lrs_wav_model
    # bb is stored as (x, y); the model's bounding box is assigned reversed.
    det_to_v2v3.bounding_box = bb[::-1]
    v23_to_world = tel2sky & models.Identity(1)

    # Now the actual pipeline.
    pipeline = [(detector, det_to_v2v3), (v2v3, v23_to_world), (world, None)]
    return pipeline
Exemple #3
0
def align_catalogs(xin,
                   yin,
                   xref,
                   yref,
                   model_guess=None,
                   translation=None,
                   translation_range=None,
                   rotation=None,
                   rotation_range=None,
                   magnification=None,
                   magnification_range=None,
                   tolerance=0.1,
                   center_of_field=None):
    """
    Match two 2D catalogs by fitting a transform from (xin, yin) to (xref, yref).

    If ``model_guess`` is given it is fitted directly (provided it is
    fittable). Otherwise an initial model is assembled from a translation,
    a rotation and a magnification. Each component is included when it has
    an initial value and/or a search range. Components without a range are
    included with their parameters held fixed, and components with neither
    are omitted. Rotation and magnification act about the centre of the field,
    which keeps the fitted translation small. The centre is either supplied or
    taken as the mid-range of the input coordinates.

    Parameters
    ----------
    xin, yin: float arrays
        input coordinates
    xref, yref: float arrays
        reference coordinates to map and match to
    model_guess: Model
        initial model guess (overrides the next parameters)
    translation: 2-tuple of floats
        initial translation guess (a scalar is used for both axes)
    translation_range: None, value, 2-tuple or 2x2-tuple
        None => fixed
        value => search range from initial guess (same for x and y)
        2-tuple => search limits (same for x and y)
        2x2-tuple => search limits for x and y
    rotation: float
        initial rotation guess (degrees)
    rotation_range: None, float, or 2-tuple
        extent of search space for rotation
    magnification: float
        initial magnification factor
    magnification_range: None, float, or 2-tuple
        extent of search space for magnification
    tolerance: float
        accuracy required for final result
    center_of_field: 2-tuple
        point about which rotation and magnification act, in the same
        coordinates as xin/yin (if None, uses middle of xin,yin ranges)

    Returns
    -------
    Model: a model that maps (xin,yin) to (xref,yref); ``Identity(2)`` when
    there is nothing to fit or ``model_guess`` is not fittable.
    """
    def _get_value_and_range(value, search):
        """Turn a guess and a search spec into (central value, (low, high)).

        Either element of the result may be None.
        """
        try:
            lo, hi = search
        except TypeError:
            # Scalar (or None): a half-width around the value.
            lo, hi = search, None
        except ValueError:
            lo, hi = None, None

        if value is None:
            if lo is None:
                return None, None
            if hi is None:
                return 0.0, (-lo, lo)
            return 0.5 * (lo + hi), (lo, hi)

        if lo is None:
            return value, None
        if hi is None:
            return value, (value - lo, value + lo)
        if lo <= value <= hi:
            return value, (lo, hi)
        # Value outside the limits: keep the width, recentre on the value.
        half_width = 0.5 * abs(hi - lo)
        return value, (value - half_width, value + half_width)

    log = logutils.get_logger(__name__)

    if model_guess is not None:
        if not model_guess.fittable:
            log.warning('The transformation is not fittable!')
            return models.Identity(2)
        init_model = model_guess
    else:
        xmin, xmax = np.min(xin), np.max(xin)
        ymin, ymax = np.min(yin), np.max(yin)
        # Half the larger extent of the input data, used to scale parameters.
        pixel_range = 0.5 * max(xmax - xmin, ymax - ymin)

        # --- translation ---
        if hasattr(translation, '__len__'):
            xoff, yoff = translation
        else:
            xoff = yoff = translation
        trange = np.array(translation_range)
        if trange.ndim == 2:
            xspec, yspec = trange[0], trange[1]
        else:
            xspec = yspec = translation_range
        xvalue, xlimits = _get_value_and_range(xoff, xspec)
        yvalue, ylimits = _get_value_and_range(yoff, yspec)

        trans_model = None
        if xvalue is not None and yvalue is not None:
            trans_model = Shift2D(xvalue, yvalue)
            if xlimits is None:
                trans_model.x_offset.fixed = True
            else:
                trans_model.x_offset.bounds = xlimits
            if ylimits is None:
                trans_model.y_offset.fixed = True
            else:
                trans_model.y_offset.bounds = ylimits

        # --- rotation ---
        # An angle error of da (degrees) shifts the edge of the data by
        # da/57.3*pixel_range, so scale the angle to make steps comparable.
        rvalue, rlimits = _get_value_and_range(rotation, rotation_range)
        rot_model = None
        if rvalue is not None:
            rot_scaling = pixel_range / 57.3
            rot_model = Rotate2D(rvalue * rot_scaling, angle_scale=rot_scaling)
            if rlimits is None:
                rot_model.angle.fixed = True
            else:
                rot_model.angle.bounds = tuple(lim * rot_scaling
                                               for lim in rlimits)

        # --- magnification ---
        # A magnification error dm shifts the edge by dm*pixel_range.
        mvalue, mlimits = _get_value_and_range(magnification,
                                               magnification_range)
        mag_model = None
        if mvalue is not None:
            mag_scaling = pixel_range
            mag_model = Scale2D(mvalue * mag_scaling, factor_scale=mag_scaling)
            if mlimits is None:
                mag_model.factor.fixed = True
            else:
                mag_model.factor.bounds = tuple(lim * mag_scaling
                                                for lim in mlimits)

        # --- assemble the compound model ---
        if rot_model is None and mag_model is None:
            if trans_model is None:
                return models.Identity(2)  # Nothing to do
            init_model = trans_model  # No need to recentre for a pure shift
        else:
            if center_of_field is None:
                center_of_field = (0.5 * (xmin + xmax), 0.5 * (ymin + ymax))
                log.debug('No center of field given, using x={:.2f} '
                          'y={:.2f}'.format(*center_of_field))
            restore = Shift2D(*center_of_field).rename('Centering')
            restore.x_offset.fixed = True
            restore.y_offset.fixed = True

            # Move to field-centred coordinates, apply the components,
            # then move back.
            init_model = restore.inverse
            for component in (trans_model, rot_model, mag_model):
                if component is not None:
                    init_model |= component
            init_model |= restore

    return fit_brute_then_simplex(init_model, (xin, yin), (xref, yref),
                                  sigma=10.0,
                                  tolerance=tolerance)
Exemple #4
0
def lrs(input_model, reference_files):
    """
    Create the WCS pipeline for a MIRI LRS fixed-slit or slitless observation.

    Parameters
    ----------
    input_model : `jwst.datamodels.ImagingModel`
        Data model. ``meta.exposure.type`` selects the zero point, and
        ``data.shape[1]`` sets how many columns are tabulated.
    reference_files : dict
        Dictionary {reftype: reference file name}, e.g.::

            reference_files = {
                "distortion": <distortion reference file>,
                "specwcs": 'MIRI_FM_MIRIMAGE_P750L_DISTORTION_04.02.00.fits'
            }

    Returns
    -------
    pipeline : list
        ``[(detector, det2world), (world, None)]``.

    Raises
    ------
    ValueError
        If the exposure type is neither ``mir_lrs-fixedslit`` nor
        ``mir_lrs-slitless`` (no zero point can be determined).
    """
    exp_type = input_model.meta.exposure.type.lower()
    if exp_type not in ('mir_lrs-fixedslit', 'mir_lrs-slitless'):
        # Previously this fell through and failed later with an unbound
        # ``zero_point``; fail early with a clear message instead.
        raise ValueError(
            "Unsupported exposure type for the LRS WCS: {!r}".format(exp_type))

    # Setup the frames.
    detector = cf.Frame2D(name='detector',
                          axes_order=(0, 1),
                          unit=(u.pix, u.pix))
    spec = cf.SpectralFrame(name='wavelength',
                            axes_order=(2, ),
                            unit=(u.micron, ),
                            axes_names=('lambda', ))
    sky = cf.CelestialFrame(reference_frame=coord.ICRS(), name='sky')
    world = cf.CompositeFrame(name="world", frames=[sky, spec])

    # Determine the distortion model.
    subarray2full = subarray_transform(input_model)
    with DistortionModel(reference_files['distortion']) as dist:
        distortion = dist.model
    # subarray_transform() may return None (presumably for full-frame data);
    # composing None with ``|`` would raise, so only prepend it when present.
    if subarray2full is not None:
        distortion = subarray2full | distortion
    # Distortion is in arcsec.  Convert to degrees
    full_distortion = distortion | models.Scale(
        1 / 3600.) & models.Scale(1 / 3600.)

    # Load and process the reference data.
    with fits.open(reference_files['specwcs']) as ref:
        lrsdata = np.array([row for row in ref[1].data])

        # Get the zero point from the reference data.
        # The zero_point is X, Y  (which should be COLUMN, ROW)
        # TODO: Are imx, imy 0- or 1-indexed?  We are treating them here as
        # 0-indexed.  Since they are FITS, they are probably 1-indexed.
        if exp_type == 'mir_lrs-fixedslit':
            zero_point = ref[1].header['imx'], ref[1].header['imy']
        else:  # 'mir_lrs-slitless' (validated above)
            # NOTE: hard-coded slitless zero point; the 'imxsltl'/'imysltl'
            # header keywords are intentionally not used here.
            zero_point = [35, 442]  # [35, 763] # account for subarray

    # Create the bounding_box from columns 3, 4 and 5 of the reference table,
    # offset by the zero point and padded by half a pixel.
    x0 = lrsdata[:, 3]
    y0 = lrsdata[:, 4]
    x1 = lrsdata[:, 5]

    bb = ((x0.min() - 0.5 + zero_point[0], x1.max() + 0.5 + zero_point[0]),
          (y0.min() - 0.5 + zero_point[1], y0.max() + 0.5 + zero_point[1]))
    # Find the ROW of the zero point which should be the [1] of zero_point
    row_zero_point = zero_point[1]

    # Compute the v2v3 to sky.
    tel2sky = pointing.v23tosky(input_model)

    # Compute the V2/V3 for each pixel in this row
    # x.shape will be something like (1, 388)
    y, x = np.mgrid[row_zero_point:row_zero_point + 1,
                    0:input_model.data.shape[1]]

    spatial_transform = full_distortion | tel2sky
    radec = np.array(spatial_transform(x, y))[:, 0, :]

    # Replicate the single computed row over every row of the bounding box.
    # np.tile(a, (n, 1)) matches numpy.matlib.repmat(a, n, 1) for 1-D ``a``
    # without depending on the deprecated numpy.matlib submodule being imported.
    nrows = _toindex(bb[1][1]) + 1 - _toindex(bb[1][0])
    ra_full = np.tile(radec[0], (nrows, 1))
    dec_full = np.tile(radec[1], (nrows, 1))

    ra_t2d = models.Tabular2D(lookup_table=ra_full,
                              name='xtable',
                              bounds_error=False,
                              fill_value=np.nan)
    dec_t2d = models.Tabular2D(lookup_table=dec_full,
                               name='ytable',
                               bounds_error=False,
                               fill_value=np.nan)

    # Create the model transforms.
    lrs_wav_model = jwmodels.LRSWavelength(lrsdata, zero_point)

    # Incorporate the small rotation
    angle = np.arctan(0.00421924)
    rot = models.Rotation2D(angle)
    radec_t2d = ra_t2d & dec_t2d | rot

    # Account for the subarray when computing spatial coordinates.
    xshift = -bb[0][0]
    yshift = -bb[1][0]
    # (x, y) -> (y, x, y, x, x, y): two shifted (y, x) pairs feed the lookup
    # tables and the unshifted (x, y) pair feeds the wavelength model.
    det2world = models.Mapping((1, 0, 1, 0, 0, 1)) | models.Shift(yshift, name='yshift1') & \
              models.Shift(xshift, name='xshift1') & \
              models.Shift(yshift, name='yshift2') & models.Shift(xshift, name='xshift2') & \
              models.Identity(2) | radec_t2d & lrs_wav_model
    det2world.bounding_box = bb[::-1]
    # Now the actual pipeline.
    pipeline = [(detector, det2world), (world, None)]

    return pipeline
Exemple #5
0
def test_ifu_bbox():
    """Check the per-slice bounding boxes of a NIRSpec IFU F290LP/G140M WCS.

    For each of the 30 slices, the slit_frame -> detector transform is built
    by inverting and chaining the first three pipeline steps. Its bounding
    box is compared against the tabulated values.
    """
    expected = {0: ((122.0908542999878, 1586.2584665188083),
                    (773.5411133037417, 825.1150258966278)),
                1: ((140.3793485788431, 1606.8904629423566),
                    (1190.353197027459, 1243.0853605832503)),
                2: ((120.0139534379125, 1583.9271768905855),
                    (724.3249534782219, 775.8104288584977)),
                3: ((142.50252648927454, 1609.3106221382388),
                    (1239.4122720740888, 1292.288713688988)),
                4: ((117.88884113088403, 1581.5517394150106),
                    (674.9787657901347, 726.3752061973377)),
                5: ((144.57465414462143, 1611.688447569682),
                    (1288.4808318659427, 1341.5035313084197)),
                6: ((115.8602297714846, 1579.27471654949),
                    (625.7982466386104, 677.1147840452901)),
                7: ((146.7944728147906, 1614.2161842198498),
                    (1337.531525654835, 1390.7050687363856)),
                8: ((113.86384530944383, 1577.0293086386203),
                    (576.5344359685643, 627.777022204828)),
                9: ((149.0259581360621, 1616.7687282225652),
                    (1386.5118806905086, 1439.843598490326)),
                10: ((111.91564190274217, 1574.8351095461135),
                     (527.229828693075, 578.402894851317)),
                11: ((151.3053466801954, 1619.3720722471498),
                     (1435.423685040875, 1488.917203728964)),
                12: ((109.8957204607345, 1572.570246400894),
                     (477.9699083444277, 529.0782087498488)),
                13: ((153.5023503173659, 1621.9005029476564),
                     (1484.38405923062, 1538.0443479389924)),
                14: ((107.98320121613297, 1570.411787034636),
                     (428.6704834494425, 479.7217241891257)),
                15: ((155.77991404913857, 1624.5184927460925),
                     (1533.169633314481, 1586.9984359105376)),
                16: ((106.10212081215678, 1568.286103827344),
                     (379.3860245240618, 430.3780648366697)),
                17: ((158.23149941845386, 1627.305849064835),
                     (1582.0496119714928, 1636.0513450787032)),
                18: ((104.09366374413436, 1566.030231370944),
                     (330.0822744105267, 381.01974582564395)),
                19: ((160.4511021152353, 1629.888830991371),
                     (1630.7797743277185, 1684.9592727079018)),
                20: ((102.25220592881234, 1563.9475099032868),
                     (280.7233309522168, 331.6093009077988)),
                21: ((162.72784286205734, 1632.5257403739463),
                     (1679.6815760587567, 1734.03692957156)),
                22: ((100.40115742738622, 1561.8476640376036),
                     (231.35443588323855, 282.19575854747006)),
                23: ((165.05939163941662, 1635.2270773628682),
                     (1728.511467615387, 1783.0485841263735)),
                24: ((98.45723949658425, 1559.6499479349648),
                     (182.0417295679079, 232.83530870639865)),
                25: ((167.44628840053574, 1637.9923229870349),
                     (1777.2512197664128, 1831.971115503598)),
                26: ((96.56508092457855, 1557.5079027818058),
                     (132.5285162704088, 183.27350269292484)),
                27: ((169.8529496136358, 1640.778485168005),
                     (1826.028691168028, 1880.9336718824313)),
                28: ((94.71390837793813, 1555.4048050512263),
                     (82.94691422559131, 133.63901517357235)),
                29: ((172.3681094850081, 1643.685604697228),
                     (1874.8184744639657, 1929.9072657798927))}

    hdul = create_nirspec_ifu_file("F290LP", "G140M")
    ifu_model = datamodels.IFUImageModel(hdul)
    ifu_model.meta.filename = "test_ifu.fits"
    refs = create_reference_files(ifu_model)

    ifu_model.meta.wcs = wcs.WCS(
        nirspec.create_pipeline(ifu_model, refs, slit_y_range=[-.5, .5]))

    _, wrange = nirspec.spectral_order_wrange_from_model(ifu_model)
    steps = ifu_model.meta.wcs.pipeline

    gwa_to_slit = steps[2].transform
    # Index 2 is a per-slice placeholder; index 3 is unused by this test.
    transforms = [steps[0].transform,
                  steps[1].transform[1:],
                  astmodels.Identity(1),
                  astmodels.Identity(1)]
    transforms += [step.transform for step in steps[4:-1]]

    for slice_index, expected_bbox in expected.items():
        transforms[2] = gwa_to_slit.get_model(slice_index)
        inverses = [tr.inverse for tr in reversed(transforms[:3])]
        slit_to_detector = functools.reduce(lambda lhs, rhs: lhs | rhs, inverses)
        computed_bbox = nirspec.compute_bounding_box(slit_to_detector, wrange)
        assert_allclose(expected_bbox, computed_bbox)
Exemple #6
0
def prepare_psf_model(psfmodel, xname=None, yname=None, fluxname=None,
                      renormalize_psf=True):
    """
    Convert a 2D PSF model to one suitable for use with
    `BasicPSFPhotometry` or its subclasses.

    The resulting model may be a composite model, but should have only
    the x, y, and flux related parameters un-fixed.

    Parameters
    ----------
    psfmodel : a 2D model
        The model to assume as representative of the PSF.
    xname : str or None
        The name of the ``psfmodel`` parameter that corresponds to the
        x-axis center of the PSF.  If None, the model will be assumed to
        be centered at x=0, and a new parameter will be added for the
        offset.
    yname : str or None
        The name of the ``psfmodel`` parameter that corresponds to the
        y-axis center of the PSF.  If None, the model will be assumed to
        be centered at y=0, and a new parameter will be added for the
        offset.
    fluxname : str or None
        The name of the ``psfmodel`` parameter that corresponds to the
        total flux of the star.  If None, a scaling factor will be added
        to the model.
    renormalize_psf : bool
        If True, the model will be integrated from -inf to inf and
        re-scaled so that the total integrates to 1.  Note that this
        renormalization only occurs *once*, so if the total flux of
        ``psfmodel`` depends on position, this will *not* be correct.

    Returns
    -------
    outmod : a model
        A new model ready to be passed into `BasicPSFPhotometry` or its
        subclasses.

    Raises
    ------
    ValueError
        If ``renormalize_psf`` is True and the integral of ``psfmodel`` is
        zero or not finite, so no normalization factor can be formed.
    """
    # Compound-model parameter names get a suffix with the submodel index:
    # (xinmod & yinmod) | psfmodel puts xinmod at 0, yinmod at 1, psfmodel at 2.
    if xname is None:
        xinmod = models.Shift(0, name='x_offset')
        xname = 'offset_0'
    else:
        xinmod = models.Identity(1)
        xname = xname + '_2'
    xinmod.fittable = True

    if yname is None:
        yinmod = models.Shift(0, name='y_offset')
        yname = 'offset_1'
    else:
        yinmod = models.Identity(1)
        yname = yname + '_2'
    yinmod.fittable = True

    outmod = (xinmod & yinmod) | psfmodel

    if fluxname is None:
        # The flux scaling becomes submodel 3 of the compound.
        outmod = outmod * models.Const2D(1, name='flux_scaling')
        fluxname = 'amplitude_3'
    else:
        fluxname = fluxname + '_2'

    if renormalize_psf:
        # we do the import here because other machinery works w/o scipy
        from scipy import integrate

        integrand = integrate.dblquad(psfmodel, -np.inf, np.inf,
                                      lambda x: -np.inf, lambda x: np.inf)[0]
        # Guard the division: a zero integral would raise ZeroDivisionError and
        # a non-finite one would silently yield a NaN/zero model.
        if not np.isfinite(integrand) or integrand == 0:
            raise ValueError('Cannot renormalize the PSF model: its integral '
                             'is {}'.format(integrand))
        normmod = models.Const2D(1./integrand, name='renormalize_scaling')
        outmod = outmod * normmod

    # final setup of the output model - fix all the non-offset/scale
    # parameters
    for pnm in outmod.param_names:
        outmod.fixed[pnm] = pnm not in (xname, yname, fluxname)

    # and set the names so that BasicPSFPhotometry knows what to do
    outmod.xname = xname
    outmod.yname = yname
    outmod.fluxname = fluxname

    # now some convenience aliases if reasonable
    outmod.psfmodel = outmod[2]
    if 'x_0' not in outmod.param_names and 'y_0' not in outmod.param_names:
        outmod.x_0 = getattr(outmod, xname)
        outmod.y_0 = getattr(outmod, yname)
    if 'flux' not in outmod.param_names:
        outmod.flux = getattr(outmod, fluxname)

    return outmod
Exemple #7
0
def fitswcs_image(header):
    """
    Make the celestial (imaging) part of a transform from FITS WCS keywords.

    The returned model chains, in order:
    a per-axis shift of ``-(CRPIX - 1)`` on each contributing pixel axis
    (so inputs are presumably 0-indexed pixel coordinates — confirm with callers),
    the CD-matrix affine transform restricted to the sky axes, and the
    non-linear projection/rotation from ``gwutils.fitswcs_nonlinear``
    (when that returns a model).

    Parameters
    ----------
    header : `astropy.io.fits.Header` or dict
        FITS Header or dict with basic FITS WCS keywords.

    Returns
    -------
    sky_model : `~astropy.modeling.Model` or None
        The transform, named 'SKY', with ``meta['input_axes']`` (contributing
        pixel axes) and ``meta['output_axes']`` (sky axes) set.
        None if no sky axes are found.

    Raises
    ------
    TypeError
        If ``header`` is neither a FITS Header nor a dict.
    ValueError
        If more than 2 pixel axes contribute to the sky coordinates.
    """
    if isinstance(header, fits.Header):
        wcs_info = read_wcs_from_header(header)
    elif isinstance(header, dict):
        wcs_info = header
    else:
        raise TypeError("Expected a FITS Header or a dict.")

    crpix = wcs_info['CRPIX']
    # NOTE(review): when ``header`` is a dict, ``cd`` is the caller's own CD
    # object, so the ghost-axis assignments below modify it in place.
    cd = wcs_info['CD']
    # get the part of the PC matrix corresponding to the imaging axes
    sky_axes, spec_axes, unknown = get_axes(wcs_info)
    if not sky_axes:
        # Fall back to treating exactly two unclassified axes as the sky.
        if len(unknown) == 2:
            sky_axes = unknown
        else:  # No sky here
            return
    pixel_axes = _get_contributing_axes(wcs_info, sky_axes)
    if len(pixel_axes) > 2:
        raise ValueError(
            "More than 2 pixel axes contribute to the sky coordinates")

    # Shift each contributing pixel axis by -(CRPIX - 1).
    translation_models = [
        models.Shift(-(crpix[i] - 1), name='crpix' + str(i + 1))
        for i in pixel_axes
    ]
    translation = functools.reduce(lambda x, y: x & y, translation_models)
    transforms = [translation]

    # If only one axis is contributing to the sky (e.g., slit spectrum)
    # then it must be that there's an extra axis in the CD matrix, so we
    # create a "ghost" orthogonal axis here so an inverse can be defined
    # Modify the CD matrix in case we have to use a backup Matrix Model later
    if len(pixel_axes) == 1:
        # Last CD column := the contributing column rotated by 90 degrees.
        cd[sky_axes[0], -1] = -cd[sky_axes[1], pixel_axes[0]]
        cd[sky_axes[1], -1] = cd[sky_axes[0], pixel_axes[0]]
        # assumes pixel_axes is a list (it is concatenated with [-1]) — confirm
        sky_cd = cd[np.ix_(sky_axes, pixel_axes + [-1])]
        affine = models.AffineTransformation2D(matrix=sky_cd, name='cd_matrix')
        # TODO: replace when PR#10362 is in astropy
        #rotation = models.fix_inputs(affine, {'y': 0})
        # Forward: duplicate the single input, replace the copy with 0 (the
        # ghost axis), then apply the 2x2 affine.
        rotation = models.Mapping(
            (0, 0)) | models.Identity(1) & models.Const1D(0) | affine
        # Inverse: invert the affine and keep only the first output.
        rotation.inverse = affine.inverse | models.Mapping((0, ), n_inputs=2)
    else:
        sky_cd = cd[np.ix_(sky_axes, pixel_axes)]
        rotation = models.AffineTransformation2D(matrix=sky_cd,
                                                 name='cd_matrix')
    transforms.append(rotation)

    # Optional non-linear (projection) part; skipped when falsy.
    projection = gwutils.fitswcs_nonlinear(wcs_info)
    if projection:
        transforms.append(projection)

    sky_model = functools.reduce(lambda x, y: x | y, transforms)
    sky_model.name = 'SKY'
    sky_model.meta.update({'input_axes': pixel_axes, 'output_axes': sky_axes})
    return sky_model
Exemple #8
0
def gwcs_to_fits(ndd, hdr=None):
    """
    Convert a gWCS object to a collection of FITS WCS keyword/value pairs,
    if possible. If the FITS WCS is only approximate, this should be indicated
    with a dict entry {'FITS-WCS': 'APPROXIMATE'}. If there is no suitable
    FITS representation, then a ValueError or NotImplementedError can be
    raised.

    Parameters
    ----------
    ndd : `astropy.nddata.NDData`
        The NDData whose wcs attribute we want converted
    hdr : `astropy.io.fits.Header`
        A Header object that may contain some useful keywords
        (``CENTWAVE`` is used as the spectral CRVAL when present)

    Returns
    -------
    dict
        values to insert into the FITS header to express this WCS

    Raises
    ------
    ValueError
        If the output frame has lon/lat axes but no CelestialFrame, if the
        transform lacks a native-to-celestial rotation or a sky projection,
        or if the projection class is not recognised.
    NotImplementedError
        If the celestial reference frame is not an equatorial (RA/Dec) frame.
    """
    if hdr is None:
        hdr = {}

    wcs = ndd.wcs
    transform = wcs.forward_transform
    world_axes = list(wcs.output_frame.axes_names)
    nworld_axes = len(world_axes)
    wcs_dict = {'WCSAXES': nworld_axes, 'WCSDIM': nworld_axes}
    wcs_dict.update({
        f'CD{i+1}_{j+1}': 0.
        for j in range(nworld_axes) for i in range(nworld_axes)
    })
    # World coordinates of the array centre (pixel coordinates are x-first).
    pix_center = [0.5 * (length - 1) for length in ndd.shape[::-1]]
    wcs_center = transform(*pix_center)

    # Find and process the sky projection first
    if {'lon', 'lat'}.issubset(world_axes):
        cel_frame = None
        if isinstance(wcs.output_frame, cf.CelestialFrame):
            cel_frame = wcs.output_frame
        elif isinstance(wcs.output_frame, cf.CompositeFrame):
            for frame in wcs.output_frame.frames:
                if isinstance(frame, cf.CelestialFrame):
                    cel_frame = frame
        if cel_frame is None:
            raise ValueError("Output frame has lon/lat axes but no "
                             "CelestialFrame")

        # TODO: Non-equatorial coordinate frames
        cel_ref_frame = cel_frame.reference_frame
        if not isinstance(cel_ref_frame, coord.builtin_frames.BaseRADecFrame):
            raise NotImplementedError("Cannot write non-equatorial frames yet")
        wcs_dict['RADESYS'] = cel_ref_frame.name.upper()

        nat2cel = None
        projcode = None
        for m in transform:
            if isinstance(m, models.RotateNative2Celestial):
                nat2cel = m
            if isinstance(m, models.Pix2SkyProjection):
                m.name = 'pix2sky'
                # Determine which sort of projection this is
                for code in projections.projcodes:
                    if isinstance(m, getattr(models, f'Pix2Sky_{code}')):
                        projcode = code
                        break
                else:
                    raise ValueError("Unknown projection class: {}".format(
                        m.__class__.__name__))

        # Without both pieces the code below would fail on unbound names;
        # report it as "no suitable FITS representation" instead.
        if nat2cel is None or projcode is None:
            raise ValueError("Transform has no sky projection and/or "
                             "native-to-celestial rotation")

        lon_axis = world_axes.index('lon')
        lat_axis = world_axes.index('lat')
        world_axes[lon_axis] = f'RA---{projcode}'
        world_axes[lat_axis] = f'DEC--{projcode}'
        wcs_dict[f'CRVAL{lon_axis+1}'] = nat2cel.lon.value
        wcs_dict[f'CRVAL{lat_axis+1}'] = nat2cel.lat.value

        # Remove projection parts so we can calculate the CD matrix
        nat2cel.name = 'nat2cel'
        transform = transform.replace_submodel('pix2sky',
                                               models.Identity(2))
        transform = transform.replace_submodel('nat2cel',
                                               models.Identity(2))

    # Deal with other axes
    # TODO: AD should refactor to allow the descriptor to be used here
    for i, axis_type in enumerate(wcs.output_frame.axes_type, start=1):
        if f'CRVAL{i}' in wcs_dict:
            continue
        if axis_type == "SPECTRAL":
            wcs_dict[f'CRVAL{i}'] = hdr.get(
                'CENTWAVE',
                wcs_center[i - 1] if nworld_axes > 1 else wcs_center)
            wcs_dict[f'CTYPE{i}'] = wcs.output_frame.axes_names[i -
                                                                1]  # AWAV/WAVE
        else:  # Just something
            wcs_dict[f'CRVAL{i}'] = 0

    # Flag if we can't construct a perfect CD matrix
    if not model_is_affine(transform):
        wcs_dict['FITS-WCS'] = ('APPROXIMATE', 'FITS WCS is approximate')

    affine = calculate_affine_matrices(transform, ndd.shape)
    # Convert to x-first order
    affine_matrix = affine.matrix[::-1, ::-1]
    # Require an inverse to write out
    if np.linalg.det(affine_matrix) == 0:
        affine_matrix[-1, -1] = 1.
    wcs_dict.update({
        f'CD{i+1}_{j+1}': affine_matrix[i, j]
        for j, _ in enumerate(world_axes) for i, _ in enumerate(world_axes)
    })
    # Don't overwrite CTYPEi keywords we've already created
    wcs_dict.update({
        f'CTYPE{i}': axis.upper()[:8]
        for i, axis in enumerate(world_axes, start=1)
        if f'CTYPE{i}' not in wcs_dict
    })

    # CRPIX is the (1-indexed) pixel position of CRVAL.
    crval = [wcs_dict[f'CRVAL{i+1}'] for i, _ in enumerate(world_axes)]
    crpix = np.array(wcs.backward_transform(*crval)) + 1
    if nworld_axes == 1:
        wcs_dict['CRPIX1'] = crpix
    else:
        # Comply with FITS standard, must define CRPIXj for "extra" axes
        wcs_dict.update({
            f'CRPIX{j}': cpix
            for j, cpix in enumerate(np.concatenate(
                [crpix, [0] * (nworld_axes - len(ndd.shape))]),
                                     start=1)
        })
    for i, unit in enumerate(wcs.output_frame.unit, start=1):
        try:
            wcs_dict[f'CUNIT{i}'] = unit.name
        except AttributeError:
            pass

    return wcs_dict
Exemple #9
0
def compute_footprint_nrs_ifu(dmodel, mod):
    """
    Determine NIRSPEC IFU footprint using the instrument model.

    For efficiency this function uses the transforms directly,
    instead of the WCS object. The common transforms in the WCS
    model chain are referenced and reused; only the slice specific
    transforms are computed.

    If the transforms change this function should be revised.

    Parameters
    ----------
    dmodel : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs.
    mod : module
        The imported ``nirspec`` module.

    Returns
    -------
    footprint : ndarray
        The spatial footprint, ordered as
        ``[ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max]``.
    spectral_region : tuple
        The wavelength range ``(lam_min, lam_max)`` for the observation.
    """
    # Per-slice samples are collected as arrays and joined once after the
    # loop, instead of boxing every sample into a Python list element.
    ra_parts = []
    dec_parts = []
    lam_parts = []
    _, wrange = mod.spectral_order_wrange_from_model(dmodel)
    pipe = dmodel.meta.wcs.pipeline

    # Get the GWA to slit_frame transform
    g2s = pipe[2].transform
    # Transform of the step starting at ``slit_frame``; only its
    # slice-specific submodel is used inside the loop.
    slit_step_transform = pipe[3].transform

    # Construct a list of the transforms between coordinate frames.
    # Set a place holder ``Identity`` transform at index 2 and 3.
    # Update them with slice specific transforms.
    transforms = [pipe[0].transform]
    transforms.append(pipe[1].transform[1:])
    transforms.append(astmodels.Identity(1))
    transforms.append(astmodels.Identity(1))
    transforms.extend([step.transform for step in pipe[4:-1]])

    # Loop over the slices (indices 0..29); each iteration swaps in the
    # slice-specific transforms at indices 2 and 3.
    for sl in range(30):
        transforms[2] = g2s.get_model(sl)
        # Create the full transform from ``slit_frame`` to ``detector``.
        # It is used to compute the bounding box.
        m = functools.reduce(lambda x, y: x | y,
                             [tr.inverse for tr in transforms[:3][::-1]])
        bbox = mod.compute_bounding_box(m, wrange)
        # Add the remaining transforms - from ``slit_frame`` to ``world``
        transforms[3] = slit_step_transform.get_model(sl) & astmodels.Identity(1)
        mforw = functools.reduce(lambda x, y: x | y, transforms)
        x1, y1 = grid_from_bounding_box(bbox)
        ra, dec, lam = mforw(x1, y1)
        ra_parts.append(np.ravel(ra))
        dec_parts.append(np.ravel(dec))
        lam_parts.append(np.ravel(lam))

    dec_total = np.concatenate(dec_parts)
    lam_total = np.concatenate(lam_parts)
    # the wrapped ra values are forced to be on one side of ra-border
    # the wrapped ra are used to determine the correct min and max ra
    ra_total = wrap_ra(np.concatenate(ra_parts))
    ra_max = np.nanmax(ra_total)
    ra_min = np.nanmin(ra_total)
    # for the footprint we want ra to be between 0 to 360
    if ra_min < 0:
        ra_min = ra_min + 360.0
    if ra_max >= 360.0:
        ra_max = ra_max - 360.0

    dec_max = np.nanmax(dec_total)
    dec_min = np.nanmin(dec_total)
    lam_max = np.nanmax(lam_total)
    lam_min = np.nanmin(lam_total)
    footprint = np.array(
        [ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max])
    return footprint, (lam_min, lam_max)