Beispiel #1
0
def _transform_corners(ads, all_corners, ref_wcs, interpolator):
    """
    Compute corner positions of each input image relative to ``ref_wcs``.

    For every AstroData object the corners of its first extension are taken
    from ``at.get_corners`` (which yields (row, col) tuples) and also stored
    in (x, y) order.

    If ``interpolator`` is None, the image centre (1-based pixels) is
    converted to sky with ``ref_wcs`` and back to pixels with the image's
    own WCS. The difference is rounded to whole pixels, reordered to
    (row, col), and added to every corner.

    Otherwise each (x, y) corner (0-based) is converted to sky with
    ``ref_wcs`` and back to pixels with the image WCS, then stored as
    (row, col).

    Parameters
    ----------
    ads : iterable
        AstroData-like objects; only ``ad[0]`` (``.hdr`` and ``.data``)
        is read.
    all_corners : list
        Receives one list of (row, col) corners per image. It is modified
        in place and also returned.
    ref_wcs : WCS
        WCS of the reference image.
    interpolator : object or None
        None selects the integer-shift path; any other value selects the
        per-corner WCS path.

    Returns
    -------
    tuple
        ``(all_corners, xy_img_corners, shifts)``. ``xy_img_corners`` holds
        the untransformed (x, y) corners of each image. ``shifts`` holds
        the rounded (row, col) offsets and is empty unless ``interpolator``
        is None.
    """
    shifts = []
    xy_img_corners = []

    for ad in ads:
        sci = ad[0]
        img_wcs = WCS(sci.hdr)
        img_shape = sci.data.shape
        yx_corners = at.get_corners(img_shape)
        # get_corners() yields (row, col); the WCS methods expect (x, y)
        xy_corners = [(c[1], c[0]) for c in yx_corners]
        xy_img_corners.append(xy_corners)

        if interpolator is not None:
            # Per-corner mapping, 0-based pixel origin throughout
            sky = ref_wcs.all_pix2world(xy_corners, 0)
            mapped = img_wcs.all_world2pix(sky, 0)
            all_corners.append([(c[1], c[0]) for c in mapped])
            continue

        # Integer-shift path: follow the field centre (1-based origin)
        centre = np.array([img_shape[1] / 2.0, img_shape[0] / 2.0])
        sky_centre = ref_wcs.all_pix2world([centre], 1)
        moved = img_wcs.all_world2pix(sky_centre, 1)[0]
        # Round to whole pixels, then reorder (x, y) -> (row, col)
        offset = np.roll(np.rint(moved - centre), 1)
        shifts.append(offset)
        all_corners.append([tuple(offset + c) for c in yx_corners])

    return all_corners, xy_img_corners, shifts
Beispiel #2
0
def _transform_corners(ads, all_corners, ref_wcs, interpolator):
    """
    Work out corner positions of each image with respect to ``ref_wcs``.

    With ``interpolator`` None, a whole-pixel (row, col) offset is derived
    from the image centre. The centre is sent through ``ref_wcs``
    pix->world and then the image WCS world->pix, with 1-based origin.
    That offset is added to each corner and recorded in ``shifts``.

    With any other ``interpolator``, each corner is sent through the same
    pair of WCS calls individually (0-based origin).

    Parameters
    ----------
    ads : iterable
        AstroData-like objects; only the first extension of each is used.
    all_corners : list
        Extended in place with one list of (row, col) corners per image,
        then returned.
    ref_wcs : WCS
        Reference WCS.
    interpolator : object or None
        Chooses between the integer-shift and per-corner paths.

    Returns
    -------
    tuple
        ``(all_corners, xy_img_corners, shifts)``: the extended list, the
        original (x, y) corners of every image, and the per-image offsets
        (only populated when ``interpolator`` is None).
    """
    def _swap(points):
        # Exchange the two leading coordinates: (row, col) <-> (x, y)
        return [(p[1], p[0]) for p in points]

    shifts = []
    xy_img_corners = []
    integer_shift = interpolator is None

    for ad in ads:
        ext = ad[0]
        img_wcs = WCS(ext.hdr)
        img_shape = ext.data.shape
        corners = at.get_corners(img_shape)
        pix_xy = _swap(corners)
        xy_img_corners.append(pix_xy)

        if integer_shift:
            half_size = np.array([img_shape[1] / 2.0, img_shape[0] / 2.0])
            landed = img_wcs.all_world2pix(
                ref_wcs.all_pix2world([half_size], 1), 1)[0]
            # whole-pixel offset, reordered from (x, y) to (row, col)
            delta = np.roll(np.rint(landed - half_size), 1)
            new_corners = [tuple(delta + c) for c in corners]
            shifts.append(delta)
        else:
            new_corners = _swap(img_wcs.all_world2pix(
                ref_wcs.all_pix2world(pix_xy, 0), 0))

        all_corners.append(new_corners)

    return all_corners, xy_img_corners, shifts
Beispiel #3
0
def test_get_corners_3d():
    """get_corners() on a 3-D shape lists all 8 corners, axis 0 fastest."""
    expected = [
        (0, 0, 0),
        (299, 0, 0),
        (0, 499, 0),
        (299, 499, 0),
        (0, 0, 399),
        (299, 0, 399),
        (0, 499, 399),
        (299, 499, 399),
    ]
    result = at.get_corners((300, 500, 400))
    assert result == expected
Beispiel #4
0
def test_get_corners_2d():
    """get_corners() on a 2-D shape lists the 4 corners as (row, col)."""
    expected = [
        (0, 0),
        (299, 0),
        (0, 499),
        (299, 499),
    ]
    result = at.get_corners((300, 500))
    assert result == expected
Beispiel #5
0
 def test_get_corners_3d(self):
     """The 8 corners of a 3-D shape come back with axis 0 varying fastest."""
     expected = [(i, j, k)
                 for k in (0, 399)
                 for j in (0, 499)
                 for i in (0, 299)]
     assert astrotools.get_corners((300, 500, 400)) == expected
Beispiel #6
0
 def test_get_corners_2d(self):
     """The 4 corners of a 2-D shape come back in (row, col) order."""
     expected = [(0, 0), (299, 0), (0, 499), (299, 499)]
     result = astrotools.get_corners((300, 500))
     assert result == expected
    def determineAstrometricSolution(self, adinputs=None, **params):
        """
        This primitive determines how to modify the WCS of each image to
        produce the best positional match between its sources (OBJCAT) and
        the REFCAT. This is a GSAOI-specific version that differs from the
        core primitive in three ways: (a) it combines all the extensions to
        provide a single offset model for the entire AD object; (b) it
        inserts this model after the static distortion correction, instead
        of at the start of the WCS forward_transform; (c) the model is two
        separate Polynomial2D objects rather than a scale/shift/rotate
        transform.

        AD objects that lack a usable REFCAT, OBJCAT, celestial WCS or any
        REFCAT sources in the field are skipped with a message. They are
        neither timestamped nor renamed.

        Parameters
        ----------
        suffix : str
            suffix to be added to output files
        initial : float
            search radius for cross-correlation (arcsec)
        final : float
            search radius for object matching (arcsec)
        rotate : bool
            allow image rotation in initial alignment of reference catalog?
        scale : bool
            allow image scaling in initial alignment of reference catalog?
        order : int
            order of polynomial fit in each ordinate
        max_iters : int
            maximum number of iterations when performing polynomial fit

        Returns
        -------
        list of AstroData
            ``adinputs``; matched inputs have their extension WCSs and
            OBJCATs replaced in place.

        Raises
        ------
        OSError
            if an input has only one extension (mosaicked/tiled data).
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]
        rotate = params["rotate"]
        scale = params["scale"]
        initial = params["initial"]
        final = params["final"]
        order = params["order"]
        max_iters = params["max_iters"]

        self._attach_static_distortion(adinputs)
        if any(self.timestamp_keys["adjustWCSToReference"] in ad.phu
               for ad in adinputs):
            # Trailing space after "not": the literals are concatenated
            log.warning("One or more inputs has been processed by "
                        f"adjustWCSToReference. {self.myself()} will not "
                        "preserve these alignments.")

        for ad in adinputs:
            if len(ad) == 1:
                raise OSError(f"{self.myself()} must be run on unmosaicked "
                              f"GSAOI data, but {ad.filename} has been "
                              "mosaicked/tiled.")
            # Check we have a REFCAT and at least one OBJCAT to match
            try:
                refcat = ad.REFCAT
            except AttributeError:
                log.warning(
                    f"No REFCAT in {ad.filename} - cannot calculate astrometry"
                )
                continue
            if not ('RAJ2000' in refcat.colnames
                    and 'DEJ2000' in refcat.colnames):
                log.warning(
                    f"REFCAT in {ad.filename} is missing RAJ2000 "
                    "and/or DEJ2000 columns - cannot calculate astrometry")
                continue
            try:
                objcat = merge_gsaoi_objcats(ad)
            except ValueError:
                log.warning(
                    f"No OBJCATs in {ad.filename} - cannot match to REFCAT")
                continue

            # An extension whose wcs is None has no output_frame at all.
            # The default lets it fall through to the warning below instead
            # of raising AttributeError.
            if not all(isinstance(getattr(ext.wcs, 'output_frame', None),
                                  cf.CelestialFrame) for ext in ad):
                log.warning("Missing CelestialFrame in at least one extension"
                            f" of {ad.filename}")
                continue

            # We're going to fit in the static-corrected coordinate frame.
            # Find its boundaries so we can cull the REFCAT.
            static_transforms = [
                ext.wcs.get_transform(ext.wcs.input_frame, "static")
                for ext in ad
            ]
            # at.get_corners() yields (row, col) == (y, x) tuples, but the
            # gWCS transform takes (x, y). Reverse each corner before
            # unpacking. For square arrays the corner set is unchanged, but
            # for non-square arrays the bounding box would otherwise be
            # wrong.
            corner_sets = []
            for ext, static in zip(ad, static_transforms):
                yx_corners = np.array(at.get_corners(ext.shape))
                corner_sets.append(np.array(static(*yx_corners[:, ::-1].T)))
            all_corners = np.concatenate(corner_sets, axis=1)
            xmin, ymin = all_corners.min(axis=1) - initial
            xmax, ymax = all_corners.max(axis=1) + initial

            # This will be the same for all extensions if the user hasn't
            # hacked it (which is something we can't really check)
            static_to_world_transform = ad[0].wcs.get_transform(
                "static", ad[0].wcs.output_frame)
            xref, yref = static_to_world_transform.inverse(
                refcat['RAJ2000'], refcat['DEJ2000'])
            in_field = np.all(
                (xref > xmin, xref < xmax, yref > ymin, yref < ymax), axis=0)
            num_ref_sources = in_field.sum()
            if num_ref_sources == 0:
                log.stdinfo(f"No REFCAT sources in field of {ad.filename}")
                continue

            m_init = (models.Shift(0, bounds={'offset': (-initial, initial)})
                      & models.Shift(0, bounds={'offset':
                                                (-initial, initial)}))
            if rotate:
                m_init = am.Rotate2D(0, bounds={'angle': (-5, 5)}) | m_init
            if scale:
                m_init = am.Scale2D(1, bounds={'factor':
                                               (0.95, 1.05)}) | m_init

            # How many objects do we want to try to match? Keep the
            # brightest ones only.
            objcat_len = len(objcat)
            if objcat_len > 2 * num_ref_sources:
                keep_num = max(2 * num_ref_sources, min(10, objcat_len))
            else:
                keep_num = objcat_len
            sorted_idx = np.argsort(objcat['MAG_AUTO'])[:keep_num]

            log.stdinfo(f"Aligning {ad.filename} with {num_ref_sources} REFCAT"
                        f" and {objcat_len} OBJCAT sources")
            in_coords = (objcat['X_STATIC'][sorted_idx],
                         objcat['Y_STATIC'][sorted_idx])
            ref_coords = (xref[in_field], yref[in_field])

            transform = fit_model(m_init,
                                  in_coords,
                                  ref_coords,
                                  sigma=0.2,
                                  tolerance=0.001,
                                  brute=True,
                                  scale=1 / 0.02)
            # This order will assign the closest OBJCAT to each REFCAT source
            matched = match_sources((xref, yref),
                                    transform(objcat['X_STATIC'],
                                              objcat['Y_STATIC']),
                                    radius=final)
            num_matched = np.sum(matched >= 0)
            log.stdinfo(f"Initial match: {num_matched} objects")

            if num_matched > 0:
                if num_matched > 2:
                    transform, matched = create_polynomial_transform(
                        transform, (objcat['X_STATIC'], objcat['Y_STATIC']),
                        (xref, yref),
                        order=order,
                        max_iters=max_iters,
                        match_radius=final,
                        log=self.log)
                    num_matched = np.sum(matched >= 0)
                    log.stdinfo(f"Final match: {num_matched} objects")
                else:
                    log.warning(
                        "Insufficient matches to perform distortion "
                        "correction - performing simple alignment only."
                        " Perhaps try increasing the value of "
                        f"'final' (using {final})?")

                # Associate REFCAT properties with their OBJCAT
                # counterparts. Remember! matched[i] is the index of the
                # OBJCAT source paired with REFCAT source i (-1 = none).
                dx, dy = [], []
                xtrans, ytrans = transform(objcat['X_STATIC'],
                                           objcat['Y_STATIC'])
                # The PA header keyword is in degrees; math.cos/sin take
                # radians.
                pa = math.radians(ad.phu['PA'])
                cospa, sinpa = math.cos(pa), math.sin(pa)
                for i, m in enumerate(matched):
                    if m < 0:
                        continue
                    try:
                        objcat['REF_NUMBER'][m] = refcat['Id'][i]
                    except KeyError:  # no such columns in REFCAT
                        pass
                    try:
                        objcat['REF_MAG'][m] = refcat['filtermag'][i]
                        objcat['REF_MAG_ERR'][m] = refcat['filtermag_err'][i]
                    except KeyError:  # no such columns in REFCAT
                        pass
                    dx.append(xref[i] - xtrans[m])
                    dy.append(yref[i] - ytrans[m])

                # Offset of the static-frame origin, rotated by PA
                x0, y0 = transform(0, 0)
                delta_ra = x0 * cospa - y0 * sinpa
                delta_dec = x0 * sinpa + y0 * cospa
                dra = np.array(dx) * cospa - np.array(dy) * sinpa
                ddec = np.array(dx) * sinpa + np.array(dy) * cospa
                dra_std, ddec_std = dra.std(), ddec.std()
                log.fullinfo(f"WCS Updated for {ad.filename}. Astrometric "
                             "offset is:")
                log.fullinfo("RA: {:6.2f} +/- {:.2f} arcsec".format(
                    delta_ra, dra_std))
                log.fullinfo("Dec:{:6.2f} +/- {:.2f} arcsec".format(
                    delta_dec, ddec_std))
                info_list = [{
                    "dra": delta_ra,
                    "dra_std": dra_std,
                    "ddec": delta_dec,
                    "ddec_std": ddec_std,
                    "nsamples": int(num_matched)
                }]
                # Report the measurement to the fitsstore
                if self.upload and "metrics" in self.upload:
                    qap.fitsstore_report(ad,
                                         "pe",
                                         info_list,
                                         self.calurl_dict,
                                         self.mode,
                                         upload=True)

                # Insert the fitted model after the "static" frame of every
                # extension's WCS, then update OBJCAT (X_WORLD, Y_WORLD)
                for index, ext in enumerate(ad):
                    # TODO: use insert_frame method
                    var_frame = cf.Frame2D(unit=(u.arcsec, u.arcsec),
                                           name="variable")
                    ext.wcs = gWCS(ext.wcs.pipeline[:1] +
                                   [(ext.wcs.pipeline[1].frame, transform),
                                    (var_frame, static_to_world_transform),
                                    (ext.wcs.output_frame, None)])
                    ext.objcat = objcat[objcat['ext_index'] == index]
                    # X_IMAGE/Y_IMAGE are 1-based; the gWCS call is 0-based
                    ext.objcat['X_WORLD'], ext.objcat['Y_WORLD'] = ext.wcs(
                        ext.objcat['X_IMAGE'] - 1, ext.objcat['Y_IMAGE'] - 1)
                    ext.objcat.remove_columns(
                        ['ext_index', 'X_STATIC', 'Y_STATIC'])
            else:
                log.stdinfo("Could not determine astrometric offset for "
                            f"{ad.filename}")

            # Timestamp and update filename
            gt.mark_history(ad, primname=self.myself(), keyword=timestamp_key)
            ad.update_filename(suffix=params["suffix"], strip=True)
        return adinputs
    def alignToReferenceFrame(self, rc):
        """
        This primitive applies the transformation encoded in the input images
        WCSs to align them with a reference image, in reference image pixel
        coordinates. The reference image is taken to be the first image in
        the input list.
        
        By default, the transformation into the reference frame is done via
        interpolation. The interpolator parameter specifies the interpolation 
        method. The options are nearest-neighbor, bilinear, or nth-order 
        spline, with n = 2, 3, 4, or 5. If interpolator is None, 
        no interpolation is done: the input image is shifted by an integer
        number of pixels, such that the center of the frame matches up as
        well as possible. The variance plane, if present, is transformed in
        the same way as the science data.
        
        The data quality plane, if present, must be handled a little
        differently. DQ flags are set bit-wise, such that each pixel is the 
        sum of any of the following values: 0=good pixel,
        1=bad pixel (from bad pixel mask), 2=nonlinear, 4=saturated, etc.
        To transform the DQ plane without losing flag information, it is
        unpacked into separate masks, each of which is transformed in the same
        way as the science data. A pixel is flagged if it had greater than
        1% influence from a bad pixel. The transformed masks are then added
        back together to generate the transformed DQ plane.
        
        In order not to lose any data, the output image arrays (including the
        reference image's) are expanded with respect to the input image arrays.
        The science and variance data arrays are padded with zeros; the DQ
        plane is padded with ones.
        
        The WCS keywords in the headers of the output images are updated
        to reflect the transformation.
        
        :param interpolator: type of interpolation desired
        :type interpolator: string, possible values are None, 'nearest', 
                            'linear', 'spline2', 'spline3', 'spline4', 
                            or 'spline5'
        
        :param trim_data: flag to indicate whether output image should be trimmed
                          to the size of the reference image.
        :type trim_data: Boolean
        """
        
        # Instantiate the log
        log = gemLog.getGeminiLog(logType=rc["logType"],
                                  logLevel=rc["logLevel"])
        
        # Log the standard "starting primitive" debug message
        log.debug(gt.log_message("primitive", "alignToReferenceFrame",
                                 "starting"))
        
        # Define the keyword to be used for the time stamp for this primitive
        timestamp_key = self.timestamp_keys["alignToReferenceFrame"]

        # Initialize the list of output AstroData objects
        adoutput_list = []
        
        # Check whether two or more input AstroData objects were provided
        adinput = rc.get_inputs_as_astrodata()
        if len(adinput) <= 1:
            log.warning("No alignment will be performed, since at least two " \
                        "input AstroData objects are required for " \
                        "alignToReferenceFrame")
            # Set the input AstroData object list equal to the output AstroData
            # objects list without further processing
            adoutput_list = adinput
        else:
            
            # Get the necessary parameters from the RC
            interpolator = rc["interpolator"]
            trim_data = rc["trim_data"]

            # make sure all images have one science extension
            for ad in adinput:
                sci_exts = ad["SCI"]
                if sci_exts is None or len(sci_exts)!=1:
                    raise Errors.InputError("Input images must have only one " +
                                            "SCI extension.")
            
            # load ndimage package if there will be interpolation
            if interpolator=="None":
                interpolator = None
            if interpolator is not None:
                from scipy.ndimage import affine_transform
            
            # get reference WCS and shape
            reference = adinput[0]
            ref_wcs = pywcs.WCS(reference["SCI"].header)
            ref_shape = reference["SCI"].data.shape
            ref_corners = at.get_corners(ref_shape)
            naxis = len(ref_shape)

            # first pass: get output image shape required to fit all
            # data in output by transforming corner coordinates of images
            all_corners = [ref_corners]
            xy_img_corners = []
            shifts = []
            for i in range(1,len(adinput)):
                
                ad = adinput[i]
                
                img_wcs = pywcs.WCS(ad["SCI"].header)
                
                img_shape = ad["SCI"].data.shape
                img_corners = at.get_corners(img_shape)
                
                xy_corners = [(corner[1],corner[0]) for corner in img_corners]
                xy_img_corners.append(xy_corners)
                
                if interpolator is None:
                    # find shift by transforming center position of field
                    # (so that center matches best)
                    x1y1 = np.array([img_shape[1]/2.0,img_shape[0]/2.0])
                    x2y2 = img_wcs.wcs_sky2pix(
                        ref_wcs.wcs_pix2sky([x1y1],1),1)[0]
                    
                    # round shift to nearest integer and flip x and y
                    offset = np.roll(np.rint(x2y2-x1y1),1)
                    
                    # shift corners of image
                    img_corners = [tuple(offset+corner) 
                                   for corner in img_corners]
                    shifts.append(offset)
                else:
                    # transform corners of image via WCS
                    xy_corners = img_wcs.wcs_sky2pix(ref_wcs.wcs_pix2sky(
                                                               xy_corners,0),0)
                    
                    img_corners = [(corner[1],corner[0]) 
                                   for corner in xy_corners]
                
                all_corners.append(img_corners)
            
            # If data should be trimmed to size of reference image,
            # output shape is same as ref_shape, and centering offsets are zero
            if trim_data:
                cenoff=[0]*naxis
                out_shape = ref_shape
            else:
                # Otherwise, use the corners of the images to get the minimum
                # required output shape to hold all data
                cenoff = []
                out_shape = []
                for axis in range(naxis):
                    # get output shape from corner values
                    cvals = [
                        corner[axis] for ic in all_corners for corner in ic]
                    out_shape.append(int(max(cvals)-min(cvals)+1))
                    
                    # if just shifting, need to set centering shift
                    # for reference image from offsets already calculated
                    if interpolator is None:
                        svals = [shift[axis] for shift in shifts]
                        # include a 0 shift for the reference image
                        # (in case it's already centered)
                        svals.append(0.0)
                        cenoff.append(-int(max(svals)))
            
                out_shape = tuple(out_shape)
            
                # if not shifting, get offset required to center reference image
                # from the size of the image
                if interpolator is not None:
                    incen = [0.5*(axlen-1) for axlen in ref_shape]
                    outcen = [0.5*(axlen-1) for axlen in out_shape]
                    cenoff = np.rint(incen) - np.rint(outcen)

            # shift the reference image to keep it in the center
            # of the new array (do the same for VAR and DQ)

            if trim_data:
                log.fullinfo("Trimming data to size of reference image")
            else:
                log.fullinfo("Growing reference image to keep all data; " +
                             "centering data, and updating WCS to account " +
                             "for shift")
                log.fullinfo("New output shape: "+repr(out_shape))
            
            ref_corners = [(corner[1]-cenoff[1]+1,corner[0]-cenoff[0]+1) # x,y
                           for corner in ref_corners]
            log.fullinfo("Setting AREA keywords in header to denote original " +
                         "data area.")
            area_keys = []
            log.fullinfo("AREATYPE = 'P4'     / Polygon with 4 vertices")
            area_keys.append(("AREATYPE","P4","Polygon with 4 vertices"))
            for i in range(len(ref_corners)):
                for axis in range(len(ref_corners[i])):
                    key_name = "AREA%i_%i" % (i+1,axis+1)
                    key_value = ref_corners[i][axis]
                    key_comment = "Vertex %i, dimension %i" % (i+1,axis+1)
                    area_keys.append((key_name,key_value,key_comment))
                    log.fullinfo("%-8s = %7.2f  / %s" % 
                                 (key_name, key_value,key_comment))
            
            for ext in reference:
                if ext.extname() not in ["SCI","VAR","DQ"]:
                    continue
                
                ref_data = ext.data
                
                # Make a blank data array to transform into
                if ext.extname()=="DQ":
                    # pad the DQ plane with 1 instead of 0, and make the data
                    # type int16
                    trans_data = np.zeros(out_shape).astype(np.int16)
                    trans_data += 1
                else:
                    trans_data = np.zeros(out_shape).astype(np.float32)
                
                trans_data[int(-cenoff[0]):int(ref_shape[0]-cenoff[0]),
                           int(-cenoff[1]):int(ref_shape[1]-cenoff[1])] = \
                           ref_data
                
                ext.data = trans_data
                
                # update the WCS in the reference image to account for the shift
                ext.set_key_value("CRPIX1", ref_wcs.wcs.crpix[0]-cenoff[1],
                                  comment=self.keyword_comments["CRPIX1"])
                ext.set_key_value("CRPIX2", ref_wcs.wcs.crpix[1]-cenoff[0],
                                  comment=self.keyword_comments["CRPIX2"])
                
                # set area keywords
                for key in area_keys:
                    ext.set_key_value(key[0],key[1],key[2])
            
            # update the WCS in the PHU as well
            reference.phu_set_key_value(
                "CRPIX1", ref_wcs.wcs.crpix[0]-cenoff[1],
                comment=self.keyword_comments["CRPIX1"])
            reference.phu_set_key_value(
                "CRPIX2", ref_wcs.wcs.crpix[1]-cenoff[0],
                comment=self.keyword_comments["CRPIX2"])
            
            out_wcs = pywcs.WCS(reference["SCI"].header)
            
            # Change the reference filename and append it to the output list
            reference.filename = gt.filename_updater(
                adinput=reference, suffix=rc["suffix"], strip=True)
            adoutput_list.append(reference)
            
            # now transform the data
            for i in range(1,len(adinput)):
                
                log.fullinfo("Starting alignment for "+ adinput[i].filename)
                
                ad = adinput[i]
                
                sciext = ad["SCI"]
                img_wcs = pywcs.WCS(sciext.header)
                img_shape = sciext.data.shape
                
                if interpolator is None:
                    
                    # recalculate shift from new reference wcs
                    x1y1 = np.array([img_shape[1]/2.0,img_shape[0]/2.0])
                    x2y2 = img_wcs.wcs_sky2pix(
                        out_wcs.wcs_pix2sky([x1y1],1),1)[0]
                    
                    shift = np.roll(np.rint(x2y2-x1y1),1)
                    if np.any(shift>0):
                        log.warning("Shift was calculated to be >0; " +
                                    "interpolator=None "+
                                    "may not be appropriate for this data.")
                        shift = np.where(shift>0,0,shift)
                    
                    # update PHU WCS keywords
                    log.fullinfo("Offsets: "+repr(np.roll(shift,1)))
                    log.fullinfo("Updating WCS to track shift in data")
                    ad.phu_set_key_value(
                        "CRPIX1", img_wcs.wcs.crpix[0]-shift[1],
                        comment=self.keyword_comments["CRPIX1"])
                    ad.phu_set_key_value(
                        "CRPIX2", img_wcs.wcs.crpix[1]-shift[0],
                        comment=self.keyword_comments["CRPIX2"])
                
                else:
                    # get transformation matrix from composite of wcs's
                    # matrix = in_sky2pix*out_pix2sky (converts output to input)
                    xy_matrix = np.dot(
                        np.linalg.inv(img_wcs.wcs.cd),out_wcs.wcs.cd)
                    # switch x and y for compatibility with numpy ordering
                    flip_xy = np.roll(np.eye(2),2)
                    matrix = np.dot(flip_xy,np.dot(xy_matrix,flip_xy))
                    matrix_det = np.linalg.det(matrix)
                    
                    # offsets: shift origin of transformation to the reference
                    # pixel by subtracting the transformation of the output
                    # reference pixel and adding the input reference pixel
                    # back in
                    refcrpix = np.roll(out_wcs.wcs.crpix,1)
                    imgcrpix = np.roll(img_wcs.wcs.crpix,1)
                    offset = imgcrpix - np.dot(matrix,refcrpix)
                    
                    # then add in the shift of origin due to dithering offset.
                    # This is the transform of the reference CRPIX position,
                    # minus the original position
                    trans_crpix = img_wcs.wcs_sky2pix(
                        out_wcs.wcs_pix2sky([out_wcs.wcs.crpix],1),1)[0]
                    trans_crpix = np.roll(trans_crpix,1)
                    offset = offset + trans_crpix-imgcrpix
                    
                    # Since the transformation really is into the reference
                    # WCS coordinate system as near as possible, just set image
                    # WCS equal to reference WCS
                    log.fullinfo("Offsets: "+repr(np.roll(offset,1)))
                    log.fullinfo("Transformation matrix:\n"+repr(matrix))
                    log.fullinfo("Updating WCS to match reference WCS")
                    ad.phu_set_key_value(
                        "CRPIX1", out_wcs.wcs.crpix[0],
                        comment=self.keyword_comments["CRPIX1"])
                    ad.phu_set_key_value(
                        "CRPIX2", out_wcs.wcs.crpix[1],
                        comment=self.keyword_comments["CRPIX2"])
                    ad.phu_set_key_value(
                        "CRVAL1", out_wcs.wcs.crval[0],
                        comment=self.keyword_comments["CRVAL1"])
                    ad.phu_set_key_value(
                        "CRVAL2", out_wcs.wcs.crval[1],
                        comment=self.keyword_comments["CRVAL2"])
                    ad.phu_set_key_value(
                        "CD1_1", out_wcs.wcs.cd[0,0],
                        comment=self.keyword_comments["CD1_1"])
                    ad.phu_set_key_value(
                        "CD1_2", out_wcs.wcs.cd[0,1],
                        comment=self.keyword_comments["CD1_2"])
                    ad.phu_set_key_value(
                        "CD2_1", out_wcs.wcs.cd[1,0],
                        comment=self.keyword_comments["CD2_1"])
                    ad.phu_set_key_value(
                        "CD2_2", out_wcs.wcs.cd[1,1],
                        comment=self.keyword_comments["CD2_2"])
                
                # transform corners to find new location of original data
                data_corners = out_wcs.wcs_sky2pix(
                    img_wcs.wcs_pix2sky(xy_img_corners[i-1],0),1)
                log.fullinfo("Setting AREA keywords in header to denote " +
                             "original data area.")
                area_keys = []
                log.fullinfo("AREATYPE = 'P4'     / Polygon with 4 vertices")
                area_keys.append(("AREATYPE","P4","Polygon with 4 vertices"))
                for i in range(len(data_corners)):
                    for axis in range(len(data_corners[i])):
                        key_name = "AREA%i_%i" % (i+1,axis+1)
                        key_value = data_corners[i][axis]
                        key_comment = "Vertex %i, dimension %i" % (i+1,axis+1)
                        area_keys.append((key_name,key_value,key_comment))
                        log.fullinfo("%-8s = %7.2f  / %s" % 
                                     (key_name, key_value,key_comment))
                
                for ext in ad:
                    extname = ext.extname()
                    
                    if extname not in ["SCI","VAR","DQ"]:
                        continue
                    
                    log.fullinfo("Transforming "+ad.filename+"["+extname+"]")
                    
                    # Access pixel data
                    img_data = ext.data
                    
                    if interpolator is None:
                        # just shift the data by an integer number of pixels
                        # (useful for noisy data, also lightning fast)
                        
                        # Make a blank data array to transform into
                        if extname=="DQ":
                            # pad the DQ plane with 1 instead of 0, and
                            # make the data type int16
                            trans_data = np.zeros(out_shape).astype(np.int16)
                            trans_data += 1
                        else:
                            trans_data = np.zeros(out_shape).astype(np.float32)
                        
                        trans_data[int(-shift[0]):int(img_shape[0]
                                                      -shift[0]),
                                   int(-shift[1]):int(img_shape[1]
                                                      -shift[1])] = img_data
                        
                        matrix_det = 1.0
                        
                        # update the wcs to track the transformation
                        ext.set_key_value(
                            "CRPIX1", img_wcs.wcs.crpix[0]-shift[1],
                            comment=self.keyword_comments["CRPIX1"])
                        ext.set_key_value(
                            "CRPIX2", img_wcs.wcs.crpix[1]-shift[0],
                            comment=self.keyword_comments["CRPIX2"])
                    
                    else:
                        # use ndimage to interpolate values
                        
                        # Interpolation method is determined by 
                        # interpolator parameter
                        if interpolator=="nearest":
                            order = 0
                        elif interpolator=="linear":
                            order = 1
                        elif interpolator=="spline2":
                            order = 2
                        elif interpolator=="spline3":
                            order = 3
                        elif interpolator=="spline4":
                            order = 4
                        elif interpolator=="spline5":
                            order = 5
                        else:
                            raise Errors.InputError("Interpolation method " +
                                                    interpolator +
                                                    " not recognized.")
                        
                        if extname=="DQ":
                            
                            # DQ flags are set bit-wise
                            # bit 1: bad pixel (1)
                            # bit 2: nonlinear (2)
                            # bit 3: saturated (4)
                            # A pixel can be 0 (good, no flags), or the sum of
                            # any of the above flags 
                            # (or any others I don't know about)
                            
                            # unpack the DQ data into separate masks
                            # NOTE: this method only works for 8-bit masks!
                            unp = (img_shape[0],img_shape[1],8)
                            unpack_data = np.unpackbits(
                                np.uint8(img_data)).reshape(unp)
                            
                            # transform each mask
                            trans_data = np.zeros(out_shape).astype(np.int16)
                            for j in range(0,8):
                                
                                # skip the transformation if there are no flags
                                # set (but always do the bad pixel mask because
                                # it is needed to mask the part of the array
                                #  that was padded out to match the reference
                                # image)
                                if not unpack_data[:,:,j].any() and j!=7:
                                    # first bit is j=7 because unpack
                                    # is backwards 
                                    continue
                                
                                mask = np.float32(unpack_data[:,:,j])
                                
                                # if bad pix bit, pad with 1. 
                                # Otherwise, pad with 0
                                if j==7:
                                    cval = 1
                                else:
                                    cval = 0
                                trans_mask = affine_transform(
                                    mask, matrix, offset=offset,
                                    output_shape=out_shape, order=order,
                                    cval=cval)
                                del mask; mask = None
                                
                                # flag any pixels with >1% influence
                                # from bad pixel
                                trans_mask = np.where(np.abs(trans_mask)>0.01,
                                                      2**(7-j),0)
                                
                                # add the flags into the overall mask
                                trans_data += trans_mask
                                del trans_mask; trans_mask = None
    
                        else:
                            
                            # transform science and variance data in the
                            # same way
                            cval = 0.0
                            trans_data = affine_transform(
                                img_data, matrix, offset=offset,
                                output_shape=out_shape, order=order, cval=cval)
                        
                        # update the wcs
                        ext.set_key_value(
                            "CRPIX1", out_wcs.wcs.crpix[0],
                            comment=self.keyword_comments["CRPIX1"])
                        ext.set_key_value(
                            "CRPIX2", out_wcs.wcs.crpix[1],
                            comment=self.keyword_comments["CRPIX2"])
                        ext.set_key_value(
                            "CRVAL1", out_wcs.wcs.crval[0],
                            comment=self.keyword_comments["CRVAL1"])
                        ext.set_key_value(
                            "CRVAL2", out_wcs.wcs.crval[1],
                            comment=self.keyword_comments["CRVAL2"])
                        ext.set_key_value(
                            "CD1_1", out_wcs.wcs.cd[0,0],
                            comment=self.keyword_comments["CD1_1"])
                        ext.set_key_value(
                            "CD1_2", out_wcs.wcs.cd[0,1],
                            comment=self.keyword_comments["CD1_2"])
                        ext.set_key_value(
                            "CD2_1", out_wcs.wcs.cd[1,0],
                            comment=self.keyword_comments["CD2_1"])
                        ext.set_key_value(
                            "CD2_2", out_wcs.wcs.cd[1,1],
                            comment=self.keyword_comments["CD2_2"])
                        
                        # set area keywords
                        for key in area_keys:
                            ext.set_key_value(key[0],key[1],key[2])
                    
                    ext.data = trans_data
                
                # if there was any scaling in the transformation, the
                # pixel size will have changed, and the output should
                # be scaled by the ratio of input pixel size to output
                # pixel size to conserve the total flux in a feature.
                # This factor is the determinant of the transformation
                # matrix.
                if (1.0-matrix_det)>1e-6:
                    log.fullinfo("Multiplying by %f to conserve flux" %
                                 matrix_det)
                    
                    # Allow the arith toolbox to do the multiplication
                    # so that variance is handled correctly
                    ad.mult(matrix_det)
                
                # Add time stamp to PHU
                gt.mark_history(adinput=ad, keyword=timestamp_key)

                # Change the filename
                ad.filename = gt.filename_updater(adinput=ad, 
                                                  suffix=rc["suffix"], 
                                                  strip=True)

                # Append the output AstroData object to the list
                # of output AstroData objects
                adoutput_list.append(ad)
        
        
        # Report the list of output AstroData objects to the reduction
        # context
        rc.report_output(adoutput_list)
        
        yield rc
Beispiel #9
0
    def resampleToCommonFrame(self, adinputs=None, **params):
        """
        This primitive applies the transformation encoded in the input images
        WCSs to align them with a reference image, in reference image pixel
        coordinates. The reference image is taken to be the first image in
        the input list.
        
        By default, the transformation into the reference frame is done via
        interpolation. The interpolator parameter specifies the interpolation 
        method. The options are nearest-neighbor, bilinear, or nth-order 
        spline, with n = 2, 3, 4, or 5. If interpolator is None, 
        no interpolation is done: the input image is shifted by an integer
        number of pixels, such that the center of the frame matches up as
        well as possible. The variance plane, if present, is transformed in
        the same way as the science data.
        
        The data quality plane, if present, must be handled a little
        differently. DQ flags are set bit-wise, such that each pixel is the 
        sum of any of the following values: 0=good pixel,
        1=bad pixel (from bad pixel mask), 2=nonlinear, 4=saturated, etc.
        To transform the DQ plane without losing flag information, it is
        unpacked into separate masks, each of which is transformed in the same
        way as the science data. A pixel is flagged if it had greater than
        1% influence from a bad pixel. The transformed masks are then added
        back together to generate the transformed DQ plane.
        
        In order not to lose any data, the output image arrays (including the
        reference image's) are expanded with respect to the input image arrays.
        The science and variance data arrays are padded with zeros; the DQ
        plane is padded with 16s.
        
        The WCS keywords in the headers of the output images are updated
        to reflect the transformation, and AREAn_m keywords are written to
        each output image to record where its original data now lie.

        Parameters
        ----------
        suffix: str
            suffix to be added to output files
        interpolator: str
            desired interpolation [nearest | linear | spline2 | spline3 |
                                   spline4 | spline5]
        trim_data: bool
            trim image to size of reference image?

        Returns
        -------
        list of AstroData
            The same input list, modified in place: every image is placed
            on the reference image's (possibly enlarged) pixel grid and its
            filename is updated with `suffix`. Returned untouched when fewer
            than two inputs are supplied.

        Raises
        ------
        IOError
            If any input does not have exactly one extension.
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        interpolator = params["interpolator"]
        trim_data = params["trim_data"]
        sfx = params["suffix"]

        if len(adinputs) < 2:
            log.warning("No alignment will be performed, since at least two "
                        "input AstroData objects are required for "
                        "resampleToCommonFrame")
            return adinputs

        if not all(len(ad) == 1 for ad in adinputs):
            raise IOError("All input images must have only one extension.")

        # --------------------  BEGIN establish reference frame  -------------------
        ref_image = adinputs[0]
        ref_wcs = WCS(ref_image[0].hdr)
        ref_shape = ref_image[0].data.shape
        ref_corners = at.get_corners(ref_shape)
        naxis = len(ref_shape)

        # first pass: get output image shape required to fit all
        # data in output by transforming corner coordinates of images
        all_corners = [ref_corners]
        corner_values = _transform_corners(adinputs[1:], all_corners, ref_wcs,
                                           interpolator)
        all_corners, xy_img_corners, shifts = corner_values
        refoff, out_shape = _shifts_and_shapes(all_corners, ref_shape, naxis,
                                               interpolator, trim_data, shifts)
        # Convert (row, col) 0-based corners into 1-based (x, y) positions
        # in the expanded output frame.
        ref_corners = [(corner[1] - refoff[1] + 1, corner[0] - refoff[0] + 1)
                       for corner in ref_corners]
        area_keys = _build_area_keys(ref_corners)

        # Move the reference pixel to account for the padding added below
        # (refoff is in (row, col) order, CRPIXn in (x, y) order).
        ref_image.hdr.set('CRPIX1', ref_wcs.wcs.crpix[0]-refoff[1],
                          self.keyword_comments["CRPIX1"])
        ref_image.hdr.set('CRPIX2', ref_wcs.wcs.crpix[1]-refoff[0],
                          self.keyword_comments["CRPIX2"])
        padding = tuple((int(-cen), out-int(ref-cen)) for cen, out, ref in
                        zip(refoff, out_shape, ref_shape))
        _pad_image(ref_image, padding)

        for key in area_keys:
            ref_image[0].hdr.set(*key)
        out_wcs = WCS(ref_image[0].hdr)
        ref_image.update_filename(suffix=sfx, strip=True)
        # -------------------- END establish reference frame -----------------------

        # --------------------   BEGIN transform data ...  -------------------------
        for ad, corners in zip(adinputs[1:], xy_img_corners):
            if interpolator:
                trans_parameters = _composite_transformation_matrix(ad,
                                        out_wcs, self.keyword_comments)
                matrix, matrix_det, img_wcs, offset = trans_parameters
            else:
                # `corners` are in this image's original pixel frame, so its
                # WCS must be read before _composite_from_ref_wcs runs, in
                # case that helper rewrites the header. Without this line
                # img_wcs would be undefined in the shift-only path.
                img_wcs = WCS(ad[0].hdr)
                shift = _composite_from_ref_wcs(ad, out_wcs,
                                                self.keyword_comments)
                matrix_det = 1.0

            # transform corners to find new location of original data
            # (input is 0-based, output AREA positions are 1-based)
            data_corners = out_wcs.all_world2pix(
                img_wcs.all_pix2world(corners, 0), 1)
            area_keys = _build_area_keys(data_corners)

            if interpolator:
                kwargs = {'matrix': matrix, 'offset': offset,
                          'order': interpolators[interpolator],
                          'output_shape': out_shape}
                new_var = None if ad[0].variance is None else \
                    affine_transform(ad[0].variance, cval=0.0, **kwargs)
                new_mask = _transform_mask(ad[0].mask if ad[0].mask is not None
                        else np.zeros_like(ad[0].data, dtype=DQ.datatype), **kwargs)
                if hasattr(ad[0], 'OBJMASK'):
                    ad[0].OBJMASK = _transform_mask(ad[0].OBJMASK, **kwargs)
                ad[0].reset(affine_transform(ad[0].data, cval=0.0, **kwargs),
                            new_mask, new_var)
            else:
                # Integer shift: just pad the arrays into the output frame.
                padding = tuple((int(-s), out-int(img-s)) for s, out, img in
                                zip(shift, out_shape, ad[0].data.shape))
                _pad_image(ad, padding)

            # A non-unity determinant means the pixel area changed; scale
            # to conserve total flux.
            if abs(1.0 - matrix_det) > 1e-6:
                log.fullinfo("Multiplying by {} to conserve flux".format(matrix_det))
                # Allow the arith toolbox to do the multiplication
                # so that variance is handled correctly
                ad.multiply(matrix_det)

            # Record where this image's original data lie in the output frame.
            # These belong on *this* image, not on the reference.
            for key in area_keys:
                ad[0].hdr.set(*key)

            # Update the filename with the requested suffix
            ad.update_filename(suffix=sfx, strip=True)
        return adinputs
Beispiel #10
0
 def test_get_corners_3d(self):
     """get_corners on a 3-D shape lists all 8 corners, first axis fastest."""
     shape = (300, 500, 400)
     # Corner order: the first index toggles fastest, the last slowest.
     want = [(a, b, c)
             for c in (0, 399)
             for b in (0, 499)
             for a in (0, 299)]
     got = astrotools.get_corners(shape)
     assert got == want
Beispiel #11
0
 def test_get_corners_2d(self):
     """get_corners on a 2-D shape lists all 4 corners, first axis fastest."""
     shape = (300, 500)
     # Corner order: the first index toggles fastest.
     want = [(row, col) for col in (0, 499) for row in (0, 299)]
     got = astrotools.get_corners(shape)
     assert got == want
Beispiel #12
0
    def resampleToCommonFrame(self, adinputs=None, **params):
        """
        This primitive applies the transformation encoded in the input images
        WCSs to align them with a reference image, in reference image pixel
        coordinates. The reference image is taken to be the first image in
        the input list.
        
        By default, the transformation into the reference frame is done via
        interpolation. The interpolator parameter specifies the interpolation 
        method. The options are nearest-neighbor, bilinear, or nth-order 
        spline, with n = 2, 3, 4, or 5. If interpolator is None, 
        no interpolation is done: the input image is shifted by an integer
        number of pixels, such that the center of the frame matches up as
        well as possible. The variance plane, if present, is transformed in
        the same way as the science data.
        
        The data quality plane, if present, must be handled a little
        differently. DQ flags are set bit-wise, such that each pixel is the 
        sum of any of the following values: 0=good pixel,
        1=bad pixel (from bad pixel mask), 2=nonlinear, 4=saturated, etc.
        To transform the DQ plane without losing flag information, it is
        unpacked into separate masks, each of which is transformed in the same
        way as the science data. A pixel is flagged if it had greater than
        1% influence from a bad pixel. The transformed masks are then added
        back together to generate the transformed DQ plane.
        
        In order not to lose any data, the output image arrays (including the
        reference image's) are expanded with respect to the input image arrays.
        The science and variance data arrays are padded with zeros; the DQ
        plane is padded with 16s.
        
        The WCS keywords in the headers of the output images are updated
        to reflect the transformation, and AREAn_m keywords are written to
        each output image to record where its original data now lie.

        Parameters
        ----------
        suffix: str
            suffix to be added to output files
        interpolator: str
            desired interpolation [nearest | linear | spline2 | spline3 |
                                   spline4 | spline5]
        trim_data: bool
            trim image to size of reference image?

        Returns
        -------
        list of AstroData
            The same input list, modified in place: every image is placed
            on the reference image's (possibly enlarged) pixel grid and its
            filename is updated with `suffix`. Returned untouched when fewer
            than two inputs are supplied.

        Raises
        ------
        IOError
            If any input does not have exactly one extension.
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        interpolator = params["interpolator"]
        trim_data = params["trim_data"]
        sfx = params["suffix"]

        if len(adinputs) < 2:
            log.warning("No alignment will be performed, since at least two "
                        "input AstroData objects are required for "
                        "resampleToCommonFrame")
            return adinputs

        if not all(len(ad) == 1 for ad in adinputs):
            raise IOError("All input images must have only one extension.")

        # --------------------  BEGIN establish reference frame  -------------------
        ref_image = adinputs[0]
        ref_wcs = WCS(ref_image[0].hdr)
        ref_shape = ref_image[0].data.shape
        ref_corners = at.get_corners(ref_shape)
        naxis = len(ref_shape)

        # first pass: get output image shape required to fit all
        # data in output by transforming corner coordinates of images
        all_corners = [ref_corners]
        corner_values = _transform_corners(adinputs[1:], all_corners, ref_wcs,
                                           interpolator)
        all_corners, xy_img_corners, shifts = corner_values
        refoff, out_shape = _shifts_and_shapes(all_corners, ref_shape, naxis,
                                               interpolator, trim_data, shifts)
        # Convert (row, col) 0-based corners into 1-based (x, y) positions
        # in the expanded output frame.
        ref_corners = [
            (corner[1] - refoff[1] + 1, corner[0] - refoff[0] + 1)  # x,y
            for corner in ref_corners
        ]
        area_keys = _build_area_keys(ref_corners)

        # Move the reference pixel to account for the padding added below
        # (refoff is in (row, col) order, CRPIXn in (x, y) order).
        ref_image.hdr.set('CRPIX1', ref_wcs.wcs.crpix[0] - refoff[1],
                          self.keyword_comments["CRPIX1"])
        ref_image.hdr.set('CRPIX2', ref_wcs.wcs.crpix[1] - refoff[0],
                          self.keyword_comments["CRPIX2"])
        padding = tuple((int(-cen), out - int(ref - cen))
                        for cen, out, ref in zip(refoff, out_shape, ref_shape))
        _pad_image(ref_image, padding)

        for key in area_keys:
            ref_image[0].hdr.set(*key)
        out_wcs = WCS(ref_image[0].hdr)
        ref_image.update_filename(suffix=sfx, strip=True)
        # -------------------- END establish reference frame -----------------------

        # --------------------   BEGIN transform data ...  -------------------------
        for ad, corners in zip(adinputs[1:], xy_img_corners):
            if interpolator:
                trans_parameters = _composite_transformation_matrix(
                    ad, out_wcs, self.keyword_comments)
                matrix, matrix_det, img_wcs, offset = trans_parameters
            else:
                # `corners` are in this image's original pixel frame, so its
                # WCS must be read before _composite_from_ref_wcs runs, in
                # case that helper rewrites the header. Without this line
                # img_wcs would be undefined in the shift-only path.
                img_wcs = WCS(ad[0].hdr)
                shift = _composite_from_ref_wcs(ad, out_wcs,
                                                self.keyword_comments)
                matrix_det = 1.0

            # transform corners to find new location of original data
            # (input is 0-based, output AREA positions are 1-based)
            data_corners = out_wcs.all_world2pix(
                img_wcs.all_pix2world(corners, 0), 1)
            area_keys = _build_area_keys(data_corners)

            if interpolator:
                kwargs = {
                    'matrix': matrix,
                    'offset': offset,
                    'order': interpolators[interpolator],
                    'output_shape': out_shape
                }
                new_var = None if ad[0].variance is None else \
                    affine_transform(ad[0].variance, cval=0.0, **kwargs)
                new_mask = _transform_mask(
                    ad[0].mask if ad[0].mask is not None else np.zeros_like(
                        ad[0].data, dtype=DQ.datatype), **kwargs)
                if hasattr(ad[0], 'OBJMASK'):
                    ad[0].OBJMASK = _transform_mask(ad[0].OBJMASK, **kwargs)
                ad[0].reset(affine_transform(ad[0].data, cval=0.0, **kwargs),
                            new_mask, new_var)
            else:
                # Integer shift: just pad the arrays into the output frame.
                padding = tuple(
                    (int(-s), out - int(img - s))
                    for s, out, img in zip(shift, out_shape, ad[0].data.shape))
                _pad_image(ad, padding)

            # A non-unity determinant means the pixel area changed; scale
            # to conserve total flux.
            if abs(1.0 - matrix_det) > 1e-6:
                log.fullinfo(
                    "Multiplying by {} to conserve flux".format(matrix_det))
                # Allow the arith toolbox to do the multiplication
                # so that variance is handled correctly
                ad.multiply(matrix_det)

            # Record where this image's original data lie in the output frame.
            # These belong on *this* image, not on the reference.
            for key in area_keys:
                ad[0].hdr.set(*key)

            # Update the filename with the requested suffix
            ad.update_filename(suffix=sfx, strip=True)
        return adinputs