Exemple #1
0
def wcs_from_spec_footprints(wcslist,
                             refwcs=None,
                             transform=None,
                             domain=None):
    """
    Create a WCS from a list of spatial/spectral WCS.

    Build-7 workaround: only the first WCS in ``wcslist`` is used to derive
    the pixel offsets and the domain of the output WCS.

    Parameters
    ----------
    wcslist : iterable of gwcs.WCS
        Input WCS objects; must contain at least one item.
    refwcs : gwcs.WCS, optional
        Supplies the input and output frames of the new WCS.  Defaults to
        the first item of ``wcslist``.
    transform : object, optional
        Accepted for API compatibility but ignored: the output transform is
        always rebuilt with ``compute_spec_transform``.
    domain : optional
        Forwarded unchanged to ``compute_spec_fiducial``.

    Returns
    -------
    gwcs.WCS
        WCS whose pixel origin is shifted to the minimum of the first input
        footprint, with ``domain`` expressed in those shifted pixels.

    Raises
    ------
    ValueError
        If ``wcslist`` is not iterable or is empty.
    TypeError
        If an item of ``wcslist``, or ``refwcs``, is not a gwcs.WCS.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")
    # Materialize once so generators can be validated and then indexed.
    wcslist = list(wcslist)
    if not wcslist:
        raise ValueError("Expected 'wcslist' to contain at least one gwcs.WCS")
    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError(
            "All items in 'wcslist' must have instance of gwcs.WCS")
    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs.  For
    # now, we just do it for one, using the api for a list of wcs.
    # Compute a fiducial point for the output frame at center of input data
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    # Create transform for output frame (a caller-supplied ``transform`` is
    # intentionally overridden here).
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    input_frame = refwcs.input_frame
    provisional = WCS(output_frame=output_frame, forward_transform=transform)

    # Run the first input footprint through the backward transform of the
    # provisional output wcs to get its extent in output pixel coordinates.
    det = provisional.backward_transform(*spec_footprint(wcslist[0]))
    axes = input_frame.axes_order
    axis_mins = [np.nanmin(det[axis]) for axis in axes]
    axis_maxs = [np.nanmax(det[axis]) for axis in axes]

    # Shift(m) maps p -> p + m, so new pixel 0 lands on the footprint minimum.
    transform = Shift(axis_mins[0]) & Shift(axis_mins[1]) | transform
    wnew = WCS(output_frame=output_frame,
               input_frame=input_frame,
               forward_transform=transform)

    # Express the domain in the *shifted* pixel coordinates of ``wnew``: an
    # unshifted pixel p becomes p - min, so each axis spans [0, max - min + 1).
    wnew.domain = [
        {
            'lower': 0.0,
            'upper': axis_max - axis_min + 1,
            'includes_lower': True,
            'includes_upper': False,
        }
        for axis_min, axis_max in zip(axis_mins, axis_maxs)
    ]
    return wnew
def wcs_from_spec_footprints(wcslist, refwcs=None, transform=None, domain=None):
    """
    Create a WCS from a list of spatial/spectral WCS.

    Build-7 workaround: only the first WCS in ``wcslist`` is used to derive
    the pixel offsets and the domain of the output WCS.

    Parameters
    ----------
    wcslist : iterable of gwcs.WCS
        Input WCS objects; must contain at least one item.
    refwcs : gwcs.WCS, optional
        Supplies the input and output frames of the new WCS.  Defaults to
        the first item of ``wcslist``.
    transform : object, optional
        Accepted for API compatibility but ignored: the output transform is
        always rebuilt with ``compute_spec_transform``.
    domain : optional
        Forwarded unchanged to ``compute_spec_fiducial``.

    Returns
    -------
    gwcs.WCS
        WCS whose pixel origin is shifted to the minimum of the first input
        footprint, with ``domain`` expressed in those shifted pixels.

    Raises
    ------
    ValueError
        If ``wcslist`` is not iterable or is empty.
    TypeError
        If an item of ``wcslist``, or ``refwcs``, is not a gwcs.WCS.
    """
    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of gwcs.WCS")
    # Materialize once so generators can be validated and then indexed.
    wcslist = list(wcslist)
    if not wcslist:
        raise ValueError("Expected 'wcslist' to contain at least one gwcs.WCS")
    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError("All items in 'wcslist' must have instance of gwcs.WCS")
    if refwcs is None:
        refwcs = wcslist[0]
    elif not isinstance(refwcs, WCS):
        raise TypeError("Expected refwcs to be an instance of gwcs.WCS.")

    # TODO: generalize an approach to do this for more than one wcs.  For
    # now, we just do it for one, using the api for a list of wcs.
    # Compute a fiducial point for the output frame at center of input data
    fiducial = compute_spec_fiducial(wcslist, domain=domain)
    # Create transform for output frame (a caller-supplied ``transform`` is
    # intentionally overridden here).
    transform = compute_spec_transform(fiducial, refwcs)
    output_frame = refwcs.output_frame
    input_frame = refwcs.input_frame
    provisional = WCS(output_frame=output_frame, forward_transform=transform)

    # Run the first input footprint through the backward transform of the
    # provisional output wcs to get its extent in output pixel coordinates.
    det = provisional.backward_transform(*spec_footprint(wcslist[0]))
    axes = input_frame.axes_order
    axis_mins = [np.nanmin(det[axis]) for axis in axes]
    axis_maxs = [np.nanmax(det[axis]) for axis in axes]

    # Shift(m) maps p -> p + m, so new pixel 0 lands on the footprint minimum.
    transform = Shift(axis_mins[0]) & Shift(axis_mins[1]) | transform
    wnew = WCS(output_frame=output_frame, input_frame=input_frame,
        forward_transform=transform)

    # Express the domain in the *shifted* pixel coordinates of ``wnew``: an
    # unshifted pixel p becomes p - min, so each axis spans [0, max - min + 1).
    wnew.domain = [
        {'lower': 0.0, 'upper': axis_max - axis_min + 1,
         'includes_lower': True, 'includes_upper': False}
        for axis_min, axis_max in zip(axis_mins, axis_maxs)
    ]
    return wnew
Exemple #3
0
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used.  The output transform maps pixel
        (x, y) to ('ra', 'dec', 'lamda').  x is shifted by the slit centre,
        scaled by the spatial scale, duplicated by ``Mapping((0, 0, 1))`` and
        sent through rotation, TAN projection and a native-to-celestial
        rotation about the slit centre.  y is scaled by the spectral scale and
        offset by the minimum wavelength.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            WCS to base the output on.  Defaults to the WCS of the first
            input model.

        Notes
        -----
        Returns None.  Results are stored on ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        x, y = wcstools.grid_from_bounding_box(refwcs.bounding_box,
                                               step=(1, 1),
                                               center=True)

        def evaluate_on_grid(model):
            # TODO: once astropy.modeling._Tabular is fixed, drop the
            # flatten()/reshape() round trip and evaluate on (x, y) directly.
            return [c.reshape(x.shape) for c in model(x.flatten(), y.flatten())]

        ra, dec, lam = evaluate_on_grid(refwcs)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        # With no Rotation2D there is nothing to undo.  Defaulting here avoids
        # a NameError when the slit centre is evaluated below.
        refwcs_minus_rot = refwcs.forward_transform
        if rotation:
            unrotate = functools.reduce(lambda a, b: a | b, rotation).inverse
            # Undo the slit rotation on (ra, dec); wavelength passes through.
            refwcs_minus_rot = refwcs.forward_transform | \
                unrotate & Identity(1)
            # Correct for this rotation in the wcs
            ra, dec, lam = evaluate_on_grid(refwcs_minus_rot)

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from bounding_box
        dx0 = refwcs.bounding_box[0][0]
        dx1 = refwcs.bounding_box[0][1]
        dy0 = refwcs.bounding_box[1][0]
        dy1 = refwcs.bounding_box[1][1]
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                       dy0 + dispersion_center_pix)
        slit_center_sky = SkyCoord(ra=slit_center[0],
                                   dec=slit_center[1],
                                   unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Wavelength along the middle grid column, NaNs dropped.
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s', slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra,
                                        slit_center_sky.dec, lon_pole)
        min_lam = np.nanmin(lam)
        mapping = Mapping((0, 0, 1))

        transform = Shift(-slit_center_pix) & Identity(1) | \
            Scale(spatial_scale) & Scale(spectral_scale) | \
            Identity(1) & Shift(min_lam) | mapping | \
            (rot | tan | skyrot) & Identity(1)

        # Input names must be strings; the pixel-grid arrays were previously
        # assigned here by mistake.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object by inverting
        # the sampled sky/wavelength grid back to output pixels.
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)
        wnew.bounding_box = tuple(
            (np.nanmin(bounding_box_grid[axis]),
             np.nanmax(bounding_box_grid[axis]))
            for axis in input_frame.axes_order)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Exemple #4
0
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used.  The slit position, angular size,
        rotation angle and spatial/spectral scales are estimated from the
        reference WCS sampled on its bounding-box grid.  A transform mapping
        pixel (x, y) to ('ra', 'dec', 'lamda') is then built: y feeds the
        spatial chain and x the spectral chain, via ``Mapping((1, 1, 0))``.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            WCS to base the output on.  Defaults to the WCS of the first
            input model.

        Notes
        -----
        Returns None.  Results are stored on ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within bounding box
        bb = refwcs.bounding_box
        det = x, y = wcstools.grid_from_bounding_box(bb,
                                                     step=(1, 1),
                                                     center=True)
        ra, dec, lam = refwcs(*det)
        x_center = int((bb[0][1] - bb[0][0]) / 2)
        y_center = int((bb[1][1] - bb[1][0]) / 2)
        log.debug("Center of bounding box: {}  {}".format(x_center, y_center))

        # Compute slit angular size, slit center sky coords.
        # For each row, interpolate (within a small window around x_center)
        # the x position where the wavelength equals the grid-centre
        # wavelength.  Rows whose centre wavelength is NaN get NaN.
        xpos = []
        sz = 3
        window = slice(x_center - sz + 1, x_center + sz)
        lam_ref = lam[y_center, x_center]
        for row in lam:
            if np.isnan(row[x_center]):
                xpos.append(np.nan)
            else:
                f = interpolate.interp1d(row[window],
                                         x[y_center, window],
                                         bounds_error=False,
                                         fill_value='extrapolate')
                xpos.append(f(lam_ref))
        valid_rows = ~np.isnan(lam[:, x_center])
        x_arg = np.array(xpos)[valid_rows]
        y_arg = y[valid_rows, x_center]
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        # Pixel numbers run in reverse order along the sampled slit.
        pix_num = np.flipud(np.arange(len(slit_ra)))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                                   dec=interpol_dec(slit_center_pix),
                                   unit=u.deg)
        log.debug('Slit center: {0}'.format(slit_center_sky))
        log.debug('Fiducial: {0}'.format(
            resample_utils.compute_spec_fiducial([refwcs])))
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales.  The slit length in pixels is
        # the hypotenuse of its x drift and its row count.
        xposn = np.array(xpos)[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        slit_rot_angle = slit_rot_angle.value
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)

        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale.value) & Scale(spatial_scale.value)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra.value,
                                        slit_center_sky.dec.value,
                                        lon_pole.value)

        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # (x, y) -> (y, y, x); the inverse picks (x, y) back out.
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)
        wnew.bounding_box = tuple(
            (np.nanmin(bounding_box_grid[axis]),
             np.nanmax(bounding_box_grid[axis]))
            for axis in input_frame.axes_order)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Exemple #5
0
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used.  The output transform maps pixel
        (x, y) to ('ra', 'dec', 'lamda').  x is shifted by the slit centre,
        scaled by the spatial scale, duplicated by ``Mapping((0, 0, 1))`` and
        sent through rotation, TAN projection and a native-to-celestial
        rotation about the slit centre.  y is scaled by the spectral scale and
        offset by the minimum wavelength.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            WCS to base the output on.  Defaults to the WCS of the first
            input model.

        Notes
        -----
        Returns None.  Results are stored on ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.  The output
        ``domain`` is the half-open pixel range [min, max + 1) per axis.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        x, y = wcstools.grid_from_domain(refwcs.domain)

        def evaluate_on_grid(model):
            # TODO: once astropy.modeling._Tabular is fixed, drop the
            # flatten()/reshape() round trip and evaluate on (x, y) directly.
            return [c.reshape(x.shape) for c in model(x.flatten(), y.flatten())]

        ra, dec, lam = evaluate_on_grid(refwcs)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        # With no Rotation2D there is nothing to undo.  Defaulting here avoids
        # a NameError when the slit centre is evaluated below.
        refwcs_minus_rot = refwcs.forward_transform
        if rotation:
            unrotate = functools.reduce(lambda a, b: a | b, rotation).inverse
            # Undo the slit rotation on (ra, dec); wavelength passes through.
            refwcs_minus_rot = refwcs.forward_transform | \
                unrotate & Identity(1)
            # Correct for this rotation in the wcs
            ra, dec, lam = evaluate_on_grid(refwcs_minus_rot)

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from domain
        dx0 = refwcs.domain[0]['lower']
        dx1 = refwcs.domain[0]['upper']
        dy0 = refwcs.domain[1]['lower']
        dy1 = refwcs.domain[1]['upper']
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                       dy0 + dispersion_center_pix)
        slit_center_sky = SkyCoord(ra=slit_center[0], dec=slit_center[1],
                                   unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Wavelength along the middle grid column, NaNs dropped.
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s', slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
                                        lon_pole)
        min_lam = np.nanmin(lam)
        mapping = Mapping((0, 0, 1))

        transform = Shift(-slit_center_pix) & Identity(1) | \
            Scale(spatial_scale) & Scale(spectral_scale) | \
            Identity(1) & Shift(min_lam) | mapping | \
            (rot | tan | skyrot) & Identity(1)

        # Input names must be strings; the pixel-grid arrays were previously
        # assigned here by mistake.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object by inverting the
        # sampled sky/wavelength grid back to output pixels.
        domain_grid = wnew.backward_transform(ra, dec, lam)
        domain = [
            {'lower': np.nanmin(domain_grid[axis]),
             'upper': np.nanmax(domain_grid[axis]) + 1,
             'includes_lower': True, 'includes_upper': False}
            for axis in input_frame.axes_order
        ]
        log.debug('Domain: {0} {1}'.format(domain[0]['lower'], domain[0]['upper']))
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Exemple #6
0
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Only the first input model is used.  The slit position, angular size,
        rotation angle and spatial/spectral scales are estimated from the
        reference WCS sampled on its domain grid.  A transform mapping pixel
        (x, y) to ('ra', 'dec', 'lamda') is then built: y feeds the spatial
        chain and x the spectral chain, via ``Mapping((1, 1, 0))``.

        Parameters
        ----------
        refwcs : gwcs.WCS, optional
            WCS to base the output on.  Defaults to the WCS of the first
            input model.

        Notes
        -----
        Returns None.  Results are stored on ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.  The output
        ``domain`` is the half-open pixel range [min, max + 1) per axis.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within domain
        det = x, y = wcstools.grid_from_domain(refwcs.domain)
        sky = refwcs(*det)
        lam = sky[2]
        domain_xsize = refwcs.domain[0]['upper'] - refwcs.domain[0]['lower']
        domain_ysize = refwcs.domain[1]['upper'] - refwcs.domain[1]['lower']
        x_center, y_center = int(domain_xsize / 2), int(domain_ysize / 2)

        # Compute slit angular size, slit center sky coords.
        # For each row, interpolate (within a small window around x_center)
        # the x position where the wavelength equals the grid-centre
        # wavelength.  Extrapolate instead of raising ValueError when that
        # wavelength falls outside a row's window.
        xpos = []
        sz = 3
        window = slice(x_center - sz + 1, x_center + sz)
        lam_ref = lam[y_center, x_center]
        for row in lam:
            if np.isnan(row[x_center]):
                xpos.append(np.nan)
            else:
                f = interpolate.interp1d(row[window], x[y_center, window],
                                         bounds_error=False,
                                         fill_value='extrapolate')
                xpos.append(f(lam_ref))
        valid_rows = ~np.isnan(lam[:, x_center])
        x_arg = np.array(xpos)[valid_rows]
        y_arg = y[valid_rows, x_center]
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        # Pixel numbers run in reverse order along the sampled slit.
        pix_num = np.flipud(np.arange(len(slit_ra)))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
                                   dec=interpol_dec(slit_center_pix),
                                   unit=u.deg)
        log.debug('Slit center: {0}'.format(slit_center_sky))
        log.debug('Fiducial: {0}'.format(resample_utils.compute_spec_fiducial([refwcs])))
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales.  The slit length in pixels is
        # the hypotenuse of its x drift and its row count.
        xposn = np.array(xpos)[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale) & Scale(spatial_scale)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra, slit_center_sky.dec,
                                        lon_pole)
        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # (x, y) -> (y, y, x); the inverse picks (x, y) back out.
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the domain in the output frame wcs object
        domain_grid = wnew.backward_transform(*sky)
        domain = [
            {'lower': np.nanmin(domain_grid[axis]),
             'upper': np.nanmax(domain_grid[axis]) + 1,
             'includes_lower': True, 'includes_upper': False}
            for axis in input_frame.axes_order
        ]
        log.debug('Domain: {0} {1}'.format(domain[1]['lower'], domain[1]['upper']))
        wnew.domain = domain

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
Exemple #7
0
def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=None):
    """
    Create a WCS from a list of input data models.

    A fiducial point in the output coordinate frame is created from the
    footprints of all WCS objects.  For a spatial frame this is the center
    of the union of the footprints.  For a spectral frame the fiducial is at
    the beginning of the footprint range.
    If ``refmodel`` is None, the first model in the list is used as the
    reference.  The output coordinate frame and projection (for celestial
    frames) are taken from ``refmodel``.
    If ``transform`` is not supplied, a compound transform is created from
    the rotation matrix (roll_ref, v3yangle, vparity) and, for sky axes, the
    pixel scale of ``refmodel``.
    The bounding box of the new WCS is always computed from the footprints
    of all input WCSs, and the pixel grid is re-centred on that box.

    Parameters
    ----------
    dmodels : list of `~jwst.datamodels.DataModel`
        A list of data models.
    refmodel : `~jwst.datamodels.DataModel`, optional
        Model whose WCS is used as the reference.  The output coordinate
        frame, the projection and a scaling and rotation transform are
        created from it.  If not supplied, the first model in the list is
        used as ``refmodel``.
    transform : `~astropy.modeling.core.Model`, optional
        A transform, passed to :meth:`~gwcs.wcstools.wcs_from_fiducial`.
        If not supplied, Rotation | Scaling is computed from ``refmodel``.
    bounding_box : tuple, optional
        Passed to ``compute_fiducial`` when computing the fiducial point.
        It does not set the output bounding box.

    Returns
    -------
    gwcs.WCS
        The combined output WCS with its ``bounding_box`` set.

    Raises
    ------
    TypeError
        If an input WCS is not a gwcs.WCS or ``refmodel`` is not a DataModel.
    """
    wcslist = [im.meta.wcs for im in dmodels]

    if not isiterable(wcslist):
        raise ValueError("Expected 'wcslist' to be an iterable of WCS objects.")

    if not all(isinstance(w, WCS) for w in wcslist):
        raise TypeError("All items in wcslist are to be instances of gwcs.WCS.")

    if refmodel is None:
        refmodel = dmodels[0]
    elif not isinstance(refmodel, DataModel):
        raise TypeError("Expected refmodel to be an instance of DataModel.")

    fiducial = compute_fiducial(wcslist, bounding_box)
    prj = astmodels.Pix2Sky_TAN()

    # The sky axes are needed both for the default transform and for
    # re-centring the output bounding box below.  Previously they were only
    # assigned when ``transform`` was None, so a caller-supplied transform
    # triggered a NameError.
    wcsinfo = pointing.wcsinfo_from_model(refmodel)
    sky_axes, _, _ = gwutils.get_axes(wcsinfo)

    if transform is None:
        ref_fiducial = compute_fiducial([refmodel.meta.wcs])

        # calc_rotation_matrix returns a flat list of 4 floats; reshape it
        # into the 2x2 matrix AffineTransformation2D expects.
        pc = np.reshape(
            calc_rotation_matrix(
                np.deg2rad(refmodel.meta.wcsinfo.roll_ref),
                np.deg2rad(refmodel.meta.wcsinfo.v3yangle),
                vparity=refmodel.meta.wcsinfo.vparity),
            (2, 2)
        )

        steps = [astmodels.AffineTransformation2D(pc)]
        if sky_axes:
            scale = compute_scale(refmodel.meta.wcs, ref_fiducial)
            steps.append(astmodels.Scale(scale) & astmodels.Scale(scale))
        transform = functools.reduce(lambda a, b: a | b, steps)

    out_frame = refmodel.meta.wcs.output_frame
    input_frame = dmodels[0].meta.wcs.input_frame
    wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame, projection=prj, transform=transform)
    # temporary fix before gwcs 0.14 is released
    # wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame, projection=prj,
    #                          transform=transform, input_frame=input_frame)
    pipe = wnew.pipeline[:]
    pipe[0] = (input_frame, pipe[0][1])
    wnew = WCS(pipe)

    # Map every input footprint into output pixel coordinates, then shift each
    # axis in place so that its minimum becomes -0.5.
    footprints = [w.footprint().T for w in wcslist]
    domain_bounds = np.hstack([wnew.backward_transform(*f) for f in footprints])
    for axs in domain_bounds:
        axs -= (axs.min() + .5)

    output_bounding_box = tuple(
        (domain_bounds[axis].min(), domain_bounds[axis].max())
        for axis in out_frame.axes_order
    )

    # Shift the detector origin to the centre of the bounding box on the sky axes.
    ax1, ax2 = np.array(output_bounding_box)[sky_axes]
    offset1 = (ax1[1] - ax1[0]) / 2
    offset2 = (ax2[1] - ax2[0]) / 2
    offsets = astmodels.Shift(-offset1) & astmodels.Shift(-offset2)

    wnew.insert_transform('detector', offsets, after=True)
    wnew.bounding_box = output_bounding_box

    return wnew