Ejemplo n.º 1
0
def wcs_from_spec_footprints(wcslist,
                             refwcs=None,
                             transform=None,
                             domain=None):
    """
    Create a WCS from a list of spatial/spectral WCS.

    Build-7 workaround: although a list is accepted, only the footprint of
    the first WCS in ``wcslist`` determines the pixel offset and the domain
    of the returned WCS.

    Parameters
    ----------
    wcslist : iterable of gwcs.WCS
        Input WCS objects.  Any iterable is accepted; it is materialized
        into a list once so that one-shot iterators (e.g. generators) work.
    refwcs : gwcs.WCS, optional
        WCS providing the input and output frames and passed to
        ``compute_spec_transform``.  Defaults to the first item of
        ``wcslist``.
    transform : optional
        Ignored; the output transform is always computed with
        ``compute_spec_transform``.  Kept for API compatibility.
    domain : optional
        Forwarded to ``compute_spec_fiducial``.

    Returns
    -------
    gwcs.WCS
        New WCS whose forward transform is the computed spectral transform
        preceded by a per-axis ``Shift`` of the footprint minimum, with its
        ``domain`` set from the footprint extent.

    Raises
    ------
    ValueError
        If ``wcslist`` is not iterable or is empty.
    TypeError
        If any item of ``wcslist``, or ``refwcs``, is not a gwcs.WCS.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")
    # Materialize once: the type check consumes the iterable, and the code
    # below both indexes it and iterates it again.
    wcslist = list(wcslist)
    if not wcslist:
        raise ValueError("Expected 'wcslist' to contain at least one gwcs.WCS")
    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError(
            "All items in 'wcslist' must have instance of gwcs.WCS")
    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs.  For
    # now, we just do it for one, using the api for a list of wcs.
    # Compute a fiducial point for the output frame at center of input data
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    # Create transform for output frame (the ``transform`` argument is
    # deliberately overridden here).
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the domain in the output frame wcs object by running the input wcs
    # footprints through the backward transform of the output wcs
    sky = [spec_footprint(w) for w in wcslist]
    domain_grid = [wnew.backward_transform(*f) for f in sky]

    # Per-axis extent of the first footprint in output pixel coordinates.
    input_frame = refwcs.input_frame
    det = domain_grid[0]
    axis_mins = [np.nanmin(det[axis]) for axis in input_frame.axes_order]
    axis_maxs = [np.nanmax(det[axis]) + 1
                 for axis in input_frame.axes_order]

    # Offset the (two) input pixel axes by the footprint minimum.
    transform = Shift(axis_mins[0]) & Shift(axis_mins[1]) | transform
    wnew = WCS(output_frame=output_frame,
               input_frame=input_frame,
               forward_transform=transform)

    # NOTE(review): the domain bounds reuse the un-shifted minimum used for
    # the Shift above; confirm they are meant to be in pre-shift pixel
    # coordinates.
    wnew.domain = [{'lower': lower,
                    'upper': upper,
                    'includes_lower': True,
                    'includes_upper': False}
                   for lower, upper in zip(axis_mins, axis_maxs)]
    return wnew
Ejemplo n.º 2
0
def wcs_from_spec_footprints(wcslist, refwcs=None, transform=None, domain=None):
    """
    Create a WCS from a list of spatial/spectral WCS.

    Build-7 workaround: although a list is accepted, only the footprint of
    the first WCS in ``wcslist`` determines the pixel offset and the domain
    of the returned WCS.

    Parameters
    ----------
    wcslist : iterable of gwcs.WCS
        Input WCS objects.  Any iterable is accepted; it is materialized
        into a list once so that one-shot iterators (e.g. generators) work.
    refwcs : gwcs.WCS, optional
        WCS providing the input and output frames and passed to
        ``compute_spec_transform``.  Defaults to the first item of
        ``wcslist``.
    transform : optional
        Ignored; the output transform is always computed with
        ``compute_spec_transform``.  Kept for API compatibility.
    domain : optional
        Forwarded to ``compute_spec_fiducial``.

    Returns
    -------
    gwcs.WCS
        New WCS whose forward transform is the computed spectral transform
        preceded by a per-axis ``Shift`` of the footprint minimum, with its
        ``domain`` set from the footprint extent.

    Raises
    ------
    ValueError
        If ``wcslist`` is not iterable or is empty.
    TypeError
        If any item of ``wcslist``, or ``refwcs``, is not a gwcs.WCS.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")
    # Materialize once: the type check consumes the iterable, and the code
    # below both indexes it and iterates it again.
    wcslist = list(wcslist)
    if not wcslist:
        raise ValueError("Expected 'wcslist' to contain at least one gwcs.WCS")
    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError("All items in 'wcslist' must have instance of gwcs.WCS")
    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs.  For
    # now, we just do it for one, using the api for a list of wcs.
    # Compute a fiducial point for the output frame at center of input data
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    # Create transform for output frame (the ``transform`` argument is
    # deliberately overridden here).
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    wnew = WCS(output_frame=output_frame, forward_transform=transform)

    # Build the domain in the output frame wcs object by running the input wcs
    # footprints through the backward transform of the output wcs
    sky = [spec_footprint(w) for w in wcslist]
    domain_grid = [wnew.backward_transform(*f) for f in sky]

    # Per-axis extent of the first footprint in output pixel coordinates.
    input_frame = refwcs.input_frame
    det = domain_grid[0]
    axis_mins = [np.nanmin(det[axis]) for axis in input_frame.axes_order]
    axis_maxs = [np.nanmax(det[axis]) + 1 for axis in input_frame.axes_order]

    # Offset the (two) input pixel axes by the footprint minimum.
    transform = Shift(axis_mins[0]) & Shift(axis_mins[1]) | transform
    wnew = WCS(output_frame=output_frame, input_frame=input_frame,
        forward_transform=transform)

    # NOTE(review): the domain bounds reuse the un-shifted minimum used for
    # the Shift above; confirm they are meant to be in pre-shift pixel
    # coordinates.
    wnew.domain = [{'lower': lower, 'upper': upper,
                    'includes_lower': True, 'includes_upper': False}
                   for lower, upper in zip(axis_mins, axis_maxs)]
    return wnew
Ejemplo n.º 3
0
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input datamodel is used (see TODOs below).

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS.  Defaults to ``self.input_models[0].meta.wcs``.

        Notes
        -----
        Returns nothing.  The results are stored on
        ``self.output_spatial_scale``, ``self.output_spectral_scale`` and
        ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        x, y = wcstools.grid_from_domain(refwcs.domain)
        ra, dec, lam = refwcs(x.flatten(), y.flatten())
        # TODO: once astropy.modeling._Tabular is fixed, take out the
        # flatten() and reshape() code above and below
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        # Default to the plain forward transform so that the slit-center
        # evaluation below is defined when there is no Rotation2D to undo.
        refwcs_minus_rot = refwcs.forward_transform
        if rotation:
            rot_slit = functools.reduce(lambda a, b: a | b, rotation)
            unrotate = rot_slit.inverse
            refwcs_minus_rot = refwcs.forward_transform | \
                unrotate & Identity(1)
            # Correct for this rotation in the wcs
            ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
            ra = ra.reshape(x.shape)
            dec = dec.reshape(x.shape)
            lam = lam.reshape(x.shape)

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from domain
        dx0 = refwcs.domain[0]['lower']
        dx1 = refwcs.domain[0]['upper']
        dy0 = refwcs.domain[1]['lower']
        dy1 = refwcs.domain[1]['upper']
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                       dy0 + dispersion_center_pix)
        slit_center_sky = SkyCoord(ra=slit_center[0], dec=slit_center[1],
                                   unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Sample the wavelength along the grid column at the slit center,
        # dropping NaNs, to estimate the wavelength step per row.
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s', slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
            lon_pole)
        min_lam = np.nanmin(lam)
        mapping = Mapping((0, 0, 1))

        # (x, y) -> (scaled x offset, scaled y + min_lam); the Mapping then
        # feeds the spatial value into both inputs of the celestial chain
        # and passes the wavelength through.
        transform = Shift(-slit_center_pix) & Identity(1) | \
            Scale(spatial_scale) & Scale(spectral_scale) | \
            Identity(1) & Shift(min_lam) | mapping | \
            (rot | tan | skyrot) & Identity(1)

        # Input names must be strings; the grid arrays ``x``/``y`` are not
        # valid model input names.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object
        domain_grid = wnew.backward_transform(ra, dec, lam)

        domain = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(domain_grid[axis])
            axis_max = np.nanmax(domain_grid[axis]) + 1
            domain.append({'lower': axis_min, 'upper': axis_max,
                'includes_lower': True, 'includes_upper': False})
        log.debug('Domain: {0} {1}'.format(domain[0]['lower'], domain[0]['upper']))
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Ejemplo n.º 4
0
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input datamodel is used (see TODOs below).  The output
        transform maps detector ``y`` onto the slit (spatial) axis and
        detector ``x`` onto wavelength.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS.  Defaults to ``self.input_models[0].meta.wcs``.

        Notes
        -----
        Returns nothing.  The results are stored on
        ``self.output_spatial_scale``, ``self.output_spectral_scale`` and
        ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        det = x, y = wcstools.grid_from_domain(refwcs.domain)
        sky = ra, dec, lam = refwcs(*det)
        domain_xsize = refwcs.domain[0]['upper'] - refwcs.domain[0]['lower']
        domain_ysize = refwcs.domain[1]['upper'] - refwcs.domain[1]['lower']
        x_center, y_center = int(domain_xsize / 2), int(domain_ysize / 2)

        # Compute slit angular size, slit center sky coords.
        # For each grid row, interpolate within a small column window around
        # x_center the x position at which that row reaches the wavelength
        # found at the grid center.
        xpos = []
        sz = 3
        window = slice(x_center - sz + 1, x_center + sz)
        for row in lam:
            if np.isnan(row[x_center]):
                xpos.append(np.nan)
            else:
                f = interpolate.interp1d(row[window], x[y_center, window])
                xpos.append(f(lam[y_center, x_center]))
        # Rows whose center column has a defined wavelength.
        valid_rows = ~np.isnan(lam[:, x_center])
        x_arg = np.array(xpos)[valid_rows]
        y_arg = y[valid_rows, x_center]
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        pix_num = np.flipud(np.arange(len(slit_ra)))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        # NOTE(review): the center of N samples indexed 0..N-1 is (N - 1) / 2;
        # this gives N / 2 - 1.  Confirm the half-pixel difference is intended.
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
            dec=interpol_dec(slit_center_pix), unit=u.deg)
        log.debug('Slit center: {0}'.format(slit_center_sky))
        log.debug('Fiducial: {0}'.format(resample_utils.compute_spec_fiducial([refwcs])))
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales
        xposn = np.array(xpos)[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        # Slit length in pixels: hypotenuse of the x drift and the row count.
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Wavelength step between adjacent columns at the grid center
        # (dispersion runs along x, matching the Mapping below).
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale) & Scale(spatial_scale)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
            lon_pole)
        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # (x, y) -> (y, y, x): y feeds both spatial inputs, x the spectral one.
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object
        domain_grid = wnew.backward_transform(*sky)
        domain = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(domain_grid[axis])
            axis_max = np.nanmax(domain_grid[axis]) + 1
            domain.append({'lower': axis_min, 'upper': axis_max,
                'includes_lower': True, 'includes_upper': False})
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Ejemplo n.º 5
0
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input datamodel is used (see TODOs below).

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS.  Defaults to ``self.input_models[0].meta.wcs``.

        Notes
        -----
        Returns nothing.  The results are stored on
        ``self.output_spatial_scale``, ``self.output_spectral_scale`` and
        ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        x, y = wcstools.grid_from_domain(refwcs.domain)
        ra, dec, lam = refwcs(x.flatten(), y.flatten())
        # TODO: once astropy.modeling._Tabular is fixed, take out the
        # flatten() and reshape() code above and below
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        # Default to the plain forward transform so that the slit-center
        # evaluation below is defined when there is no Rotation2D to undo.
        refwcs_minus_rot = refwcs.forward_transform
        if rotation:
            rot_slit = functools.reduce(lambda a, b: a | b, rotation)
            unrotate = rot_slit.inverse
            refwcs_minus_rot = refwcs.forward_transform | \
                unrotate & Identity(1)
            # Correct for this rotation in the wcs
            ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
            ra = ra.reshape(x.shape)
            dec = dec.reshape(x.shape)
            lam = lam.reshape(x.shape)

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from domain
        dx0 = refwcs.domain[0]['lower']
        dx1 = refwcs.domain[0]['upper']
        dy0 = refwcs.domain[1]['lower']
        dy1 = refwcs.domain[1]['upper']
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                       dy0 + dispersion_center_pix)
        slit_center_sky = SkyCoord(ra=slit_center[0],
                                   dec=slit_center[1],
                                   unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Sample the wavelength along the grid column at the slit center,
        # dropping NaNs, to estimate the wavelength step per row.
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s', slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra,
                                        slit_center_sky.dec, lon_pole)
        min_lam = np.nanmin(lam)
        mapping = Mapping((0, 0, 1))

        # (x, y) -> (scaled x offset, scaled y + min_lam); the Mapping then
        # feeds the spatial value into both inputs of the celestial chain
        # and passes the wavelength through.
        transform = Shift(-slit_center_pix) & Identity(1) | \
            Scale(spatial_scale) & Scale(spectral_scale) | \
            Identity(1) & Shift(min_lam) | mapping | \
            (rot | tan | skyrot) & Identity(1)

        # Input names must be strings; the grid arrays ``x``/``y`` are not
        # valid model input names.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object
        domain_grid = wnew.backward_transform(ra, dec, lam)

        domain = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(domain_grid[axis])
            axis_max = np.nanmax(domain_grid[axis]) + 1
            domain.append({
                'lower': axis_min,
                'upper': axis_max,
                'includes_lower': True,
                'includes_upper': False
            })
        log.debug('Domain: {0} {1}'.format(domain[0]['lower'],
                                           domain[0]['upper']))
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'],
                                           domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Ejemplo n.º 6
0
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input datamodel is used (see TODOs below).  The output
        transform maps detector ``y`` onto the slit (spatial) axis and
        detector ``x`` onto wavelength.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            Reference WCS.  Defaults to ``self.input_models[0].meta.wcs``.

        Notes
        -----
        Returns nothing.  The results are stored on
        ``self.output_spatial_scale``, ``self.output_spectral_scale`` and
        ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        det = x, y = wcstools.grid_from_domain(refwcs.domain)
        sky = ra, dec, lam = refwcs(*det)
        domain_xsize = refwcs.domain[0]['upper'] - refwcs.domain[0]['lower']
        domain_ysize = refwcs.domain[1]['upper'] - refwcs.domain[1]['lower']
        x_center, y_center = int(domain_xsize / 2), int(domain_ysize / 2)

        # Compute slit angular size, slit center sky coords.
        # For each grid row, interpolate within a small column window around
        # x_center the x position at which that row reaches the wavelength
        # found at the grid center.
        xpos = []
        sz = 3
        window = slice(x_center - sz + 1, x_center + sz)
        for row in lam:
            if np.isnan(row[x_center]):
                xpos.append(np.nan)
            else:
                f = interpolate.interp1d(row[window], x[y_center, window])
                xpos.append(f(lam[y_center, x_center]))
        # Rows whose center column has a defined wavelength.
        valid_rows = ~np.isnan(lam[:, x_center])
        x_arg = np.array(xpos)[valid_rows]
        y_arg = y[valid_rows, x_center]
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        pix_num = np.flipud(np.arange(len(slit_ra)))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        # NOTE(review): the center of N samples indexed 0..N-1 is (N - 1) / 2;
        # this gives N / 2 - 1.  Confirm the half-pixel difference is intended.
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                                   dec=interpol_dec(slit_center_pix),
                                   unit=u.deg)
        log.debug('Slit center: {0}'.format(slit_center_sky))
        log.debug('Fiducial: {0}'.format(
            resample_utils.compute_spec_fiducial([refwcs])))
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales
        xposn = np.array(xpos)[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        # Slit length in pixels: hypotenuse of the x drift and the row count.
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Wavelength step between adjacent columns at the grid center
        # (dispersion runs along x, matching the Mapping below).
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale) & Scale(spatial_scale)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra,
                                        slit_center_sky.dec, lon_pole)
        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # (x, y) -> (y, y, x): y feeds both spatial inputs, x the spectral one.
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object
        domain_grid = wnew.backward_transform(*sky)
        domain = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(domain_grid[axis])
            axis_max = np.nanmax(domain_grid[axis]) + 1
            domain.append({
                'lower': axis_min,
                'upper': axis_max,
                'includes_lower': True,
                'includes_upper': False
            })
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'],
                                           domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew