Ejemplo n.º 1
0
def _transform_corners(ads, all_corners, ref_wcs, interpolator):
    """Map the corners of each input image into the reference frame.

    For every ``ad`` in ``ads`` the corners from ``at.get_corners`` (numpy
    (row, col) order) are either
      * shifted by an integer offset, found by projecting the image centre
        through ``ref_wcs`` and back through the image WCS, when
        ``interpolator`` is None, or
      * transformed corner-by-corner through both WCSs otherwise.

    Args:
        ads: sequence of images; the WCS is read from ``ad[0].hdr`` and the
            shape from ``ad[0].data``.
        all_corners (list): extended in place with one corner list per image.
        ref_wcs: reference WCS.
        interpolator: None selects the integer-shift mode.

    Returns:
        tuple: ``(all_corners, xy_img_corners, shifts)`` where
        ``xy_img_corners`` holds each image's untransformed corners in (x, y)
        order and ``shifts`` the (row, col) integer offsets (only filled when
        ``interpolator`` is None).
    """
    shifts = []
    xy_img_corners = []

    for ad in ads:
        img_wcs = WCS(ad[0].hdr)
        img_shape = ad[0].data.shape
        corners_rc = at.get_corners(img_shape)
        corners_xy = [(rc[1], rc[0]) for rc in corners_rc]
        xy_img_corners.append(corners_xy)

        if interpolator is not None:
            # Full WCS transform of every corner (0-based pixels), back to (row, col).
            mapped_xy = img_wcs.all_world2pix(
                ref_wcs.all_pix2world(corners_xy, 0), 0)
            all_corners.append([(xy[1], xy[0]) for xy in mapped_xy])
            continue

        # Integer shift chosen so that the image centres match best.
        centre_xy = np.array([img_shape[1] / 2.0, img_shape[0] / 2.0])
        mapped_centre = img_wcs.all_world2pix(
            ref_wcs.all_pix2world([centre_xy], 1), 1)[0]
        # Round to whole pixels and flip (x, y) -> (row, col).
        offset = np.roll(np.rint(mapped_centre - centre_xy), 1)
        all_corners.append([tuple(offset + rc) for rc in corners_rc])
        shifts.append(offset)

    return all_corners, xy_img_corners, shifts
def getCompleteness(brick, raR, decR):
    """Get the completeness (depth) values for a list of ra and dec in a brick.

    Each of the g, r and z depth maps
    ``../decalsDepthMasks/decals-<brick>-depth-<band>.fits.gz`` is sampled at
    the pixel containing every (ra, dec) position.

    Parameters
    ----------
    brick : str
        Brick name used to build the depth-map file names.
    raR, decR : array-like
        Sky positions, in the units of the depth maps' WCS (presumably degrees).

    Returns
    -------
    value_g, value_r, value_z : numpy.ndarray
        Depth-map values at each position, one array per band.

    Raises
    ------
    IndexError
        If a position falls outside a depth map.
    """
    values = []
    for band in ("g", "r", "z"):
        path = "../decalsDepthMasks/decals-" + brick + "-depth-" + band + ".fits.gz"
        w = WCS(path)
        x, y = w.all_world2pix(raR, decR, 0)
        # With origin=0 the centre of pixel i sits at coordinate i, so the
        # containing pixel is floor(coord + 0.5); int() truncation picks the
        # wrong pixel for half of the positions.
        col = n.floor(n.atleast_1d(x) + 0.5).astype(int)
        row = n.floor(n.atleast_1d(y) + 0.5).astype(int)
        with fits.open(path) as image:
            data = image[0].data
            # numpy stores FITS images as [row (y), column (x)].
            outside = ((row < 0) | (row >= data.shape[0]) |
                       (col < 0) | (col >= data.shape[1]))
            if n.any(outside):
                raise IndexError("%d position(s) fall outside %s"
                                 % (n.sum(outside), path))
            values.append(n.array(data[row, col]))
    value_g, value_r, value_z = values
    return value_g, value_r, value_z
Ejemplo n.º 3
0
def test_sip2pv():
    """
    Convert SIP keywords to PV keywords and check that the converted and the
    untouched header give the same ra/dec <--> x/y transforms.

    A 10x10 grid of pixel positions spanning the image is mapped to the sky
    through both WCSs (results must agree exactly), and each WCS must map its
    own sky positions back onto the grid to 4 decimal places.
    """
    sip_header = fits.Header.fromtextfile(os.path.join(dir_name, 'data/IRAC_3.6um_sip.txt'))
    untouched_sip = sip_header.copy()
    grid_x, grid_y = np.meshgrid(np.linspace(1, sip_header['NAXIS1'], 10),
                                 np.linspace(1, sip_header['NAXIS2'], 10))
    pixels = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # sip_header is converted in place; afterwards it carries PV keywords.
    sip_to_pv(sip_header)

    wcs_pv = WCS(sip_header)
    wcs_sip = WCS(untouched_sip)

    sky_pv = wcs_pv.all_pix2world(pixels, 1)
    sky_sip = wcs_sip.all_pix2world(pixels, 1)
    npt.assert_equal(sky_pv, sky_sip)

    back_pv = wcs_pv.all_world2pix(sky_pv, 1)
    back_sip = wcs_sip.all_world2pix(sky_sip, 1)
    npt.assert_almost_equal(back_pv, pixels, 4)
    npt.assert_almost_equal(back_sip, pixels, 4)
Ejemplo n.º 4
0
def test_pv2sip():
    """
    Convert PV (TPV) keywords to SIP keywords and check that the converted and
    the untouched header give the same pixel -> sky mapping, and that
    world -> pixel round-trips for both.

    Uses a 10x10 grid of pixel positions spanning the image.
    """
    pv_header = fits.Header.fromtextfile(os.path.join(dir_name, 'data/PTF_r_chip01_tpv.txt'))
    untouched_pv = pv_header.copy()
    grid_x, grid_y = np.meshgrid(np.linspace(1, pv_header['NAXIS1'], 10),
                                 np.linspace(1, pv_header['NAXIS2'], 10))
    pixels = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # pv_header is converted in place; afterwards it carries SIP keywords.
    pv_to_sip(pv_header)

    wcs_sip = WCS(pv_header)
    wcs_pv = WCS(untouched_pv)

    sky_sip = wcs_sip.all_pix2world(pixels, 1)
    sky_pv = wcs_pv.all_pix2world(pixels, 1)
    npt.assert_equal(sky_sip, sky_pv)

    back_sip = wcs_sip.all_world2pix(sky_sip, 1)
    back_pv = wcs_pv.all_world2pix(sky_pv, 1)
    npt.assert_almost_equal(back_sip, pixels, 4)
    npt.assert_almost_equal(back_pv, pixels, 4)
Ejemplo n.º 5
0
def get_facet_values(facet, ra, dec, root="facet", default=0):
    """
    Extract the value from a fits facet file

    Looks up the facet-map pixel containing each (ra, dec) position and
    returns labels of the form ``"<root>_<value>"``.

    Parameters
    ----------
    facet : str
        Path to a FITS facet map with four WCS axes (x, y, freq, stokes); the
        map itself is read from ``data[0, 0, :, :]`` of the primary HDU.
    ra, dec : float or array-like
        Sky positions, in the units of the map's celestial WCS axes.
    root : str
        Prefix of the returned labels.
    default : number
        Value used for positions outside the map or on NaN pixels.

    Returns
    -------
    numpy.ndarray of str
        One label per position, the value formatted with no decimals.
    """
    import numpy as np
    from astropy.io import fits
    from astropy.wcs import WCS

    # TODO: Check astropy version
    # TODO: Check facet is a fits file

    with fits.open(facet) as f:
        shape = f[0].data.shape

        w = WCS(f[0].header)
        freq = w.wcs.crval[2]
        stokes = w.wcs.crval[3]

        # origin=0 gives numpy-style (0-based) pixel coordinates, so the
        # rounded values index ``data`` directly. origin=1 would return FITS
        # 1-based coordinates and shift every lookup by one pixel.
        xe, ye, _1, _2 = w.all_world2pix(ra, dec, freq, stokes, 0)
        # atleast_1d so scalar ra/dec work too (a 0-d result can't be iterated).
        x = np.atleast_1d(np.round(xe).astype(int))
        y = np.atleast_1d(np.round(ye).astype(int))

        # Points outside the map look up pixel 0 as a placeholder and are
        # overwritten with ``default`` below.
        outside = (x < 0) | (x >= shape[-1]) | (y < 0) | (y >= shape[-2])
        x[outside] = 0
        y[outside] = 0

        data = f[0].data[0, 0, :, :]

        values = data[y, x]

        # Assign the default value to points out of the fits area and NaNs
        values[outside] = default
        values[np.isnan(values)] = default

        #TODO: Flexible format for other data types ?
        return np.array(["{}_{:.0f}".format(root, val) for val in values])
Ejemplo n.º 6
0
def get_facet_values(facet, ra, dec, root="facet", default=0, pad_index=False):
    """
    Extract the value from a fits facet file

    Looks up the facet-map pixel containing each (ra, dec) position and
    returns labels of the form ``"<root>_<value>"``.

    Parameters
    ----------
    facet : str
        Path to a FITS facet map with either 2 WCS axes (x, y) or 4 WCS axes
        (x, y, freq, stokes; the map is then read from ``data[0, 0, :, :]``).
    ra, dec : float or array-like
        Sky positions, in the units of the map's celestial WCS axes.
    root : str
        Prefix of the returned labels.
    default : number
        Value used for positions outside the map or on NaN pixels.
    pad_index : bool
        If True, zero-pad the integer values to the number of digits needed to
        count the distinct values found.

    Returns
    -------
    numpy.ndarray of str
        One label per position.

    Raises
    ------
    ValueError
        If the map has neither 2 nor 4 WCS axes.
    """
    import numpy as np
    from astropy.io import fits
    from astropy.wcs import WCS

    # TODO: Check astropy version
    # TODO: Check facet is a fits file

    with fits.open(facet) as f:
        shape = f[0].data.shape

        w = WCS(f[0].header)
        naxis = len(w.wcs.crval)
        # origin=0 gives numpy-style (0-based) pixel coordinates, so the
        # rounded values index ``data`` directly. origin=1 would return FITS
        # 1-based coordinates and shift every lookup by one pixel.
        if naxis == 4:
            freq = w.wcs.crval[2]
            stokes = w.wcs.crval[3]
            xe, ye, _1, _2 = w.all_world2pix(ra, dec, freq, stokes, 0)
        elif naxis == 2:
            xe, ye = w.all_world2pix(ra, dec, 0)
        else:
            raise ValueError('Input mask must have 2 axes (x, y) or 4 axes (x, y, freq, stokes).')
        # atleast_1d so scalar ra/dec work too (a 0-d result can't be iterated).
        x = np.atleast_1d(np.round(xe).astype(int))
        y = np.atleast_1d(np.round(ye).astype(int))

        # Points outside the map look up pixel 0 as a placeholder and are
        # overwritten with ``default`` below.
        outside = (x < 0) | (x >= shape[-1]) | (y < 0) | (y >= shape[-2])
        x[outside] = 0
        y[outside] = 0

        if naxis == 4:
            data = f[0].data[0, 0, :, :]
        else:
            data = f[0].data[:, :]

        values = data[y, x]

        # Assign the default value to points out of the fits area and NaNs
        values[outside] = default
        values[np.isnan(values)] = default

        if pad_index:
            # Width depends only on the distinct-value count: compute it once
            # instead of once per label.
            width = int(np.ceil(np.log10(len(set(values)) + 1)))
            return np.array(["{0}_{1}".format(root, str(int(val)).zfill(width)) for val in values])
        return np.array(["{0}_{1}".format(root, int(val)) for val in values])
Ejemplo n.º 7
0
def find_mask_for_hdu(hdu, threshold=0.001, v1=-100.0, v2=100.0):
    """Mask pixels below ``threshold``, except in a window around line centre.

    The velocity limits ``v1``/``v2`` are converted to wavelengths with
    ``vels2waves`` (using the rest wavelength of WCS key 'A'). They are then
    converted to pixel positions along the first WCS axis, which this code
    indexes as the last numpy axis. Pixels in that window are set to False.

    Returns:
        numpy.ndarray of bool: True where ``hdu.data < threshold`` outside the
        window. NOTE(review): assumes a 3-D cube with the spectral axis last
        in numpy order and ``v1`` mapping to the lower pixel -- confirm.
    """
    wcs_a = WCS(hdu.header, key='A')
    window_waves = vels2waves([v1, v2], wcs_a.wcs.restwav, hdu.header)
    spectral_pix, _, _ = wcs_a.all_world2pix(window_waves, [0, 0], [0, 0], 0)
    lo_pix, hi_pix = spectral_pix
    start, stop = int(lo_pix), int(hi_pix) + 1

    below = hdu.data < threshold
    below[:, :, start:stop] = False
    return below
def makeRDCatalog(nR, brick, ramin, ramax, decmin, decmax):
    """Write a random ra/dec catalogue with DECaLS g/r/z depths for a brick.

    ``nR`` positions are drawn uniformly in [ramin, ramax] x [decmin, decmax].
    Each of the depth maps
    ``../decalsDepthMasks/decals-<brick>-depth-<band>.fits.gz`` is sampled at
    the pixel containing each position. The result is saved to
    ``../decalsRandoms/random-<brick>.dat.gz`` as columns
    ra, dec, depth_g, depth_r, depth_z.

    Raises
    ------
    IndexError
        If a random position falls outside a depth map.
    """
    # make random points in delimit ra dec min max
    raR = n.random.uniform(low=ramin, high=ramax, size=nR)
    decR = n.random.uniform(low=decmin, high=decmax, size=nR)
    depths = []
    for band in ("g", "r", "z"):
        path = "../decalsDepthMasks/decals-" + brick + "-depth-" + band + ".fits.gz"
        w = WCS(path)
        x, y = w.all_world2pix(raR, decR, 0)
        # origin=0: pixel i is centred on coordinate i, so round (floor(c+0.5))
        # rather than truncate with int().
        col = n.floor(n.atleast_1d(x) + 0.5).astype(int)
        row = n.floor(n.atleast_1d(y) + 0.5).astype(int)
        with fits.open(path) as image:
            data = image[0].data
            # numpy stores FITS images as [row (y), column (x)]; the original
            # data[x][y] lookup was transposed.
            outside = ((row < 0) | (row >= data.shape[0]) |
                       (col < 0) | (col >= data.shape[1]))
            if n.any(outside):
                raise IndexError("%d position(s) fall outside %s"
                                 % (n.sum(outside), path))
            depths.append(n.array(data[row, col]))
    value_g, value_r, value_z = depths
    n.savetxt("../decalsRandoms/random-"+brick+".dat.gz", n.transpose([ raR , decR,value_g, value_r ,value_z]) , fmt= '%3.7f %3.7f %10.4f %10.4f %10.4f')
Ejemplo n.º 9
0
def analyze_polar_rotation(pole_fn, *args, **kwargs):
    """Locate the celestial pole in an image of the polar region.

    The file is first plate-solved with ``img_utils.get_solve_field`` (keyword
    arguments are forwarded; extra positional arguments are accepted but
    ignored). The WCS read from the file then converts RA=360, Dec=90 to pixel
    coordinates.

    Args:
        pole_fn (str): FITS file of celestial pole

    Returns:
        tuple: ``(x, y)`` pixel position of the pole in 1-based (FITS)
        coordinates, exactly as returned by ``WCS.all_world2pix`` (not
        rounded to integers).
    """
    img_utils.get_solve_field(pole_fn, **kwargs)

    pole_x, pole_y = WCS(pole_fn).all_world2pix(360, 90, 1)
    return pole_x, pole_y
Ejemplo n.º 10
0
def photometry(fits_file, radius):
    """Aperture photometry at the RA/DEC position stored in a FITS header.

    Parameters
    ----------
    fits_file : str
        Path to a FITS image whose primary header provides ``RA``, ``DEC``
        and ``EXPTIME``.
    radius : float
        Aperture radius in arcseconds. NOTE(review): the original code built
        ``radius*u.kpc*u.arcsec`` (an invalid unit); arcsec is assumed here --
        confirm with callers.

    Returns
    -------
    astropy.table.QTable or None
        The ``aperture_photometry`` result, or None when ``fits_file`` does
        not exist.
    """
    from astropy.coordinates import SkyCoord

    if not os.path.isfile(fits_file):
        print("it is not a fit file")
        return None

    with fits.open(fits_file) as hdulist:
        header = hdulist[0].header
        data = hdulist[0].data
        wcs = WCS(header)
        # NOTE(review): assumes RA/DEC header values are decimal degrees.
        position = SkyCoord(header["RA"], header["DEC"], unit=u.deg)
        # Sky apertures need a SkyCoord position and an angular radius; the
        # WCS is passed to aperture_photometry to place it on the pixel grid.
        aperture = SkyCircularAperture(position, r=radius * u.arcsec)
        exp_time = header["EXPTIME"]
        # Poisson error of a count-rate image: sqrt(rate * t) / t.
        photo_error = np.sqrt(data / exp_time)
        return aperture_photometry(data, aperture, error=photo_error, wcs=wcs)
Ejemplo n.º 11
0
def _composite_transformation_matrix(ad, out_wcs, keyword_comments):
    """Build the linear pixel transform from the output (reference) WCS into ``ad``.

    The matrix is ``inv(CD_img) . CD_out`` re-expressed in numpy (y, x)
    order. The offset combines the CRPIX difference between the two WCSs
    with the shift obtained by mapping the reference CRPIX to the sky with
    ``out_wcs`` and back to pixels with the image WCS.

    Side effect: the CRPIX1/2, CRVAL1/2 and CD1_1..CD2_2 keywords of
    ``ad.hdr`` are overwritten with the values from ``out_wcs``.

    Parameters
    ----------
    ad : presumably AstroData
        Input image; its WCS is built from ``ad[0].hdr``.
    out_wcs : WCS
        Reference (output) WCS; its ``wcs.cd``, ``wcs.crpix`` and
        ``wcs.crval`` are used.
    keyword_comments : dict
        Maps keyword names (e.g. "CRPIX1", "CD1_2") to header comment strings.

    Returns
    -------
    tuple
        ``(matrix, matrix_det, img_wcs, offset)``. ``matrix`` and ``offset``
        are in (y, x) order, ``matrix_det`` is the determinant of ``matrix``
        and ``img_wcs`` is the WCS built from ``ad[0].hdr`` at the start of
        the call.
    """
    log = logutils.get_logger(__name__)
    img_wcs = WCS(ad[0].hdr)
    # get transformation matrix from composite of wcs's
    # matrix = in_sky2pix*out_pix2sky (converts output to input)
    xy_matrix = np.dot(np.linalg.inv(img_wcs.wcs.cd), out_wcs.wcs.cd)
    # switch x and y for compatibility with numpy ordering
    # (np.roll without an axis rolls the flattened identity: [[0, 1], [1, 0]])
    flip_xy = np.roll(np.eye(2), 2)
    matrix = np.dot(flip_xy,np.dot(xy_matrix, flip_xy))
    matrix_det = np.linalg.det(matrix)

    # offsets: shift origin of transformation to the reference
    # pixel by subtracting the transformation of the output
    # reference pixel and adding the input reference pixel
    # back in
    # (np.roll(..., 1) turns the (x, y) CRPIX pairs into (y, x))
    refcrpix = np.roll(out_wcs.wcs.crpix, 1)
    imgcrpix = np.roll(img_wcs.wcs.crpix, 1)
    offset = imgcrpix - np.dot(matrix, refcrpix)

    # then add in the shift of origin due to dithering offset.
    # This is the transform of the reference CRPIX position,
    # minus the original position
    trans_crpix = img_wcs.all_world2pix(
                  out_wcs.all_pix2world([out_wcs.wcs.crpix],1), 1)[0]
    trans_crpix = np.roll(trans_crpix, 1)
    offset = offset + trans_crpix-imgcrpix

    # Since the transformation really is into the reference
    # WCS coordinate system as near as possible, just set image
    # WCS equal to reference WCS
    log.fullinfo("Offsets: "+repr(np.roll(offset, 1)))
    log.fullinfo("Transformation matrix:\n"+repr(matrix))
    log.fullinfo("Updating WCS to match reference WCS")

    # Copy CRPIXn, CRVALn and the full CD matrix from the reference WCS.
    for ax in (1, 2):
        ad.hdr.set('CRPIX{}'.format(ax), out_wcs.wcs.crpix[ax-1],
                   comment=keyword_comments["CRPIX{}".format(ax)])
        ad.hdr.set('CRVAL{}'.format(ax), out_wcs.wcs.crval[ax-1],
                    comment=keyword_comments["CRVAL{}".format(ax)])
        for ax2 in (1, 2):
            ad.hdr.set('CD{}_{}'.format(ax,ax2), out_wcs.wcs.cd[ax-1,ax2-1],
                       comment=keyword_comments["CD{}_{}".format(ax,ax2)])

    return (matrix, matrix_det, img_wcs, offset) # ad ?
def grab_connected_postage_stamps(sci_frame, ref_frame, yc, xc, Yrange=50, Xrange=50):
    """Cut matching postage stamps from a science and a reference frame.

    NaN pixels of each frame are first replaced by the median of that frame's
    finite pixels (on a copy). The science stamp spans
    ``round(yc -/+ Yrange)`` by ``round(xc -/+ Xrange)``. The position
    (``xc``, ``yc``) is taken to the sky with the science WCS and back to
    pixels with the reference WCS (origin 0). The reference stamp is cut
    around that position, shifted by +1 pixel, and rotated by 180 degrees
    (``rot90(..., 2)``).

    Args:
        sci_frame: HDU-list-like science image (``[0].header``, ``[0].data``).
        ref_frame: HDU-list-like reference image, same layout.
        yc, xc (float): stamp centre in science-frame pixels.
        Yrange, Xrange (int): half-height / half-width of the stamps.

    Returns:
        tuple: ``(science_stamp, reference_stamp)`` 2-D arrays.

    NOTE(review): ``zc``, ``x`` and ``y`` are not defined in this function and
    must be module-level names (presumably the pixel origin for
    ``all_pix2world`` and the column/row indices into the bounds lists) --
    confirm, otherwise a NameError is raised.
    """
    def _nan_to_median(image):
        # Copy the image and replace NaNs with the median of the finite pixels.
        filled = image.copy()
        filled[where(isnan(filled))] = median(filled[where(~isnan(filled))])
        return filled

    sci_wcs = WCS(sci_frame[0].header)
    ref_wcs = WCS(ref_frame[0].header)

    ref_pixels = _nan_to_median(ref_frame[0].data)
    sci_pixels = _nan_to_median(sci_frame[0].data)

    sci_bounds = [[int(round(yc - Yrange)), int(round(yc + Yrange))],
                  [int(round(xc - Xrange)), int(round(xc + Xrange))]]

    target_ra, target_dec = array(sci_wcs.all_pix2world(xc, yc, zc))
    ref_px, ref_py = array(ref_wcs.all_world2pix(target_ra, target_dec, 0.0))

    ref_bounds = [[int(round(ref_py - Yrange)) + 1, int(round(ref_py + Yrange)) + 1],
                  [int(round(ref_px - Xrange)) + 1, int(round(ref_px + Xrange)) + 1]]

    sci_stamp = sci_pixels[sci_bounds[y][0]:sci_bounds[y][1],
                           sci_bounds[x][0]:sci_bounds[x][1]]
    ref_cut = ref_pixels[ref_bounds[y][0]:ref_bounds[y][1],
                         ref_bounds[x][0]:ref_bounds[x][1]]
    return sci_stamp, rot90(ref_cut, 2)
Ejemplo n.º 13
0
def _composite_from_ref_wcs(ad, out_wcs, keyword_comments):
    """Compute the integer (row, col) shift of ``ad`` onto ``out_wcs`` and record it.

    The image centre is mapped to the sky with ``out_wcs`` and back to pixels
    with the image's own WCS (``ad[0].hdr``). The difference is rounded to
    whole pixels and flipped to (row, col) order. Positive components are
    clipped to 0 with a warning. CRPIX1/CRPIX2 in ``ad.hdr`` are decreased by
    the column/row shift respectively.

    Returns:
        numpy.ndarray: the (row, col) shift after clipping.
    """
    log = logutils.get_logger(__name__)
    img_wcs = WCS(ad[0].hdr)
    img_shape = ad[0].data.shape
    n_rows, n_cols = img_shape[0], img_shape[1]

    # recalculate shift from new reference wcs
    centre_xy = np.array([n_cols / 2.0, n_rows / 2.0])
    mapped_xy = img_wcs.all_world2pix(out_wcs.all_pix2world([centre_xy], 1), 1)[0]
    shift = np.roll(np.rint(mapped_xy - centre_xy), 1)

    has_positive = np.any(shift > 0)
    if has_positive:
        log.warning("Shift was calculated to be > 0; interpolator=None "
                    "may not be appropriate for this data.")
        shift = np.where(shift > 0, 0, shift)

    log.fullinfo("Offsets: " + repr(np.roll(shift, 1)))
    log.fullinfo("Updating WCS to track shift in data")
    # CRPIX1 is the x (column) axis -> shift[1]; CRPIX2 is y (row) -> shift[0].
    for axis, shift_idx in ((1, 1), (2, 0)):
        key = "CRPIX{}".format(axis)
        ad.hdr.set(key, img_wcs.wcs.crpix[axis - 1] - shift[shift_idx],
                   comment=keyword_comments[key])
    return shift  # ad ?
Ejemplo n.º 14
0
	
	print("Length of the dr2 table: %d"%len(dr2nearby))
	
	# Move table into a dictionary object
	dr2Objects = []
	for row in dr2nearby:
		IPHAS2name = row['IPHAS2']
		dr2Object={}
		dr2Object['name'] = IPHAS2name
		dr2Object['ra'] = row['RAJ2000']
		dr2Object['dec'] = row['DEJ2000']
		dr2Object['class'] = row['mergedClass']
		dr2Object['pStar'] = row['pStar']
		dr2Object['iClass'] = row['iClass']
		dr2Object['haClass'] = row['haClass']
		x, y = wcsSolution.all_world2pix([dr2Object['ra']], [dr2Object['dec']], 1)
		dr2Object['x'] = x[0]
		dr2Object['y'] = y[0]
		dr2Objects.append(dr2Object)
		
	
	
	# Run through all objects and find the ones that are haClass = "+1" but iClass != "+1" and overall class = "+1"
	extendedHaSources = []
	for index, d in enumerate(dr2Objects):
		if (d['class'] ==1) and (d['haClass'] == 1) and (d['iClass'] != 1):
			extendedHaSources.append(index)
	
	print("Found %d extended Ha sources out of %d total objects in DR2."%(len(extendedHaSources), len(dr2Objects)))
	
	
Ejemplo n.º 15
0
def imalign(images, square=False):
    """Trim a set of overlapping images onto a common pixel grid.

    A Gaia catalogue for the first image's footprint is fetched, and every
    catalogue star is projected (1-based pixels) into each image. The image
    whose first projected star has the smallest x is the reference. Each
    image is copied to ``temp<j>.fits`` and trimmed on its low x / y edges by
    the median integer offset to the reference. All images are then cropped
    at the high end to a common size and written to ``<name>_match.fits``.

    Parameters
    ----------
    images : sequence
        Image objects exposing the file name as ``.f`` (the last five
        characters, presumably ``.fits``, are replaced in the output name).
        ``naxis1``, ``naxis2``, ``x_img``, ``y_img``, ``x_shift``,
        ``y_shift``, ``x_std``, ``y_std``, ``x_size`` and ``y_size`` are set
        on each.
    square : bool
        Currently unused.

    Raises
    ------
    ValueError
        If an image needs a positive shift relative to the reference, which
        low-edge trimming cannot apply.
    """
    from astropy.wcs import WCS
    from astropy.io import fits

    img_cat = images[0]  # arbitrarily pick the first image to get a catalog with

    # fetch a gaia catalog for the full image footprint
    outputg = img_cat.f + '.gaia'
    gaia_cat = odi.get_gaia_coords(images[0], ota='None', inst='podi', output=outputg)

    # Plain float work arrays (np.zeros_like(images) would build an array out
    # of the image objects themselves).
    n_images = len(images)
    x0s = np.zeros(n_images)
    xsizes = np.zeros(n_images)
    ysizes = np.zeros(n_images)

    # convert the catalogue ra, dec to image coordinates in each image
    for j, img in enumerate(images):
        with fits.open(img.f) as hdu_img:
            img.naxis1 = hdu_img[0].header['NAXIS1']
            img.naxis2 = hdu_img[0].header['NAXIS2']
            w_img = WCS(hdu_img[0].header)
            img.x_img, img.y_img = w_img.all_world2pix(gaia_cat.ra, gaia_cat.dec, 1)
        x0s[j] = img.x_img[0]

    # Reference = image whose first catalogue star has the smallest x, so the
    # x shift of every other image relative to it is <= 0.
    # NOTE(review): the choice ignores y, so another image may still need a
    # positive y shift and trigger the ValueError below.
    ref_idx = np.argmin(x0s)
    img_ref = images[ref_idx]

    with fits.open(img_ref.f) as hdu_ref:
        w_ref = WCS(hdu_ref[0].header)
        x_ref, y_ref = w_ref.all_world2pix(gaia_cat.ra, gaia_cat.dec, 1)

    for j, img in enumerate(images):
        # compute the pair-wise integer pixel shifts between the image and the reference
        img.x_shift = np.rint(np.median(x_ref - img.x_img))
        img.y_shift = np.rint(np.median(y_ref - img.y_img))
        img.x_std = np.rint(np.std(x_ref - img.x_img))
        img.y_std = np.rint(np.std(y_ref - img.y_img))
        # figure out how wide the trimmed image is
        img.x_size = np.rint(img.naxis1 + img.x_shift)
        img.y_size = np.rint(img.naxis2 + img.y_shift)
        xsizes[j] = img.x_size
        ysizes[j] = img.y_size

        # Shift by trimming the low edge of the copy; only non-positive shifts
        # can be applied this way.
        temp = 'temp{:1d}.fits'.format(j)
        iraf.imcopy(img.f, temp)
        if not img.x_shift < 0.5:
            raise ValueError("cannot align {}: positive x shift of {} pixels "
                             "relative to reference {}".format(img.f, img.x_shift, img_ref.f))
        trim_img = 'temp{:1d}.fits[{:d}:{:d},*]'.format(j, int(abs(img.x_shift)) + 1, img.naxis1)
        iraf.imcopy(trim_img, temp)

        if not img.y_shift < 0.5:
            raise ValueError("cannot align {}: positive y shift of {} pixels "
                             "relative to reference {}".format(img.f, img.y_shift, img_ref.f))
        trim_img = 'temp{:1d}.fits[*,{:d}:{:d}]'.format(j, int(abs(img.y_shift)) + 1, img.naxis2)
        iraf.imcopy(trim_img, temp)

    # figure out what the smallest image size is
    min_xsize = np.min(xsizes)
    min_ysize = np.min(ysizes)

    for j, img in enumerate(images):
        # take any excess pixels off the high end of each dimension so that the
        # images have identical dimensions
        # NOTE(review): the "-1" keeps min_size-1 pixels, one fewer than all
        # trimmed images share -- confirm whether intended.
        new_img = img.f[:-5] + '_match.fits'
        trim_img = 'temp{:1d}.fits[1:{:d},1:{:d}]'.format(j, int(min_xsize) - 1, int(min_ysize) - 1)
        iraf.imcopy(trim_img, new_img)
        # delete the temporary working images
        iraf.imdelete('temp{:1d}.fits'.format(j))
Ejemplo n.º 16
0
    def determineAstrometricSolution(self, adinputs=None, **params):
        """
        This primitive determines how to modify the WCS of each image to
        produce the best positional match between its sources (OBJCAT) and
        the REFCAT.

        Extensions are processed in order of decreasing OBJCAT length. The
        best model found so far seeds (and narrows the search box for) the
        following extensions. When at least one source matches, the extension
        WCS keywords and the OBJCAT X_WORLD/Y_WORLD columns are updated, and
        the offsets (arcsec) are logged and optionally reported to FITSstore.

        Parameters
        ----------
        initial: float
            search radius for cross-correlation (arcsec)
        final: float
            search radius for object matching (arcsec)
        full_wcs: bool (or None)
            use an updated WCS for each matching iteration, rather than simply
            applying pixel-based corrections to the initial mapping?
            (None => not ('qa' in mode))
        suffix: str
            suffix added to the output filenames

        Returns
        -------
        list
            ``adinputs``, updated in place.
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]
        full_wcs = params["full_wcs"]

        for ad in adinputs:
            # Check we have a REFCAT and at least one OBJCAT to match
            try:
                refcat = ad.REFCAT
            except AttributeError:
                log.warning("No REFCAT in {} - cannot calculate astrometry".
                            format(ad.filename))
                continue
            if not any(hasattr(ext, 'OBJCAT') for ext in ad):
                log.warning("No OBJCATs in {} - cannot match to REFCAT".
                            format(ad.filename))
                continue

            # List of values to report to FITSstore
            info_list = []
            all_delta_ra = []
            all_delta_dec = []

            # Try to be clever here, and work on the extension with the
            # highest number of matches first, as this will give the most
            # reliable offsets, which can then be used to constrain the other
            # extensions. The problem is we don't know how many matches we'll
            # get until we do it, and that's slow, so use len(OBJCAT) as a proxy.
            objcat_lengths = [len(ext.OBJCAT) if hasattr(ext, 'OBJCAT') else 0
                              for ext in ad]
            objcat_order = np.argsort(objcat_lengths)[::-1]

            pixscale = ad.pixel_scale()
            initial = params["initial"] / pixscale  # Search box size
            final = params["final"] / pixscale  # Matching radius
            max_ref_sources = 100 if 'qa' in self.mode else None  # No more than this
            if full_wcs is None:
                full_wcs = not ('qa' in self.mode)

            best_model = (0, None)

            for index in objcat_order:
                ext = ad[index]
                extver = ext.hdr['EXTVER']
                try:
                    objcat = ad[index].OBJCAT
                except AttributeError:
                    log.stdinfo('No OBJCAT in {}:{} -- cannot perform '
                                'astrometry'.format(ad.filename, extver))
                    info_list.append({})
                    continue
                objcat_len = len(objcat)

                # The reference coordinates are always (x,y) pixels in the OBJCAT
                # Set up the input coordinates
                wcs = WCS(ad[index].hdr)
                xref, yref = refcat['RAJ2000'], refcat['DEJ2000']
                if not full_wcs:
                    xref, yref = wcs.all_world2pix(xref, yref, 1)

                # Now set up the initial model
                if full_wcs:
                    m_init = Pix2Sky(wcs, direction=-1)
                    m_init.factor.fixed = True
                    m_init.angle.fixed = True
                    if best_model[1] is None:
                        m_init.x_offset.bounds = (-initial, initial)
                        m_init.y_offset.bounds = (-initial, initial)
                    else:
                        # Copy parameters from best model to this model
                        # TODO: if rotation/scaling are used, the factor_ will need to
                        # be copied
                        for p in best_model[1].param_names:
                            setattr(m_init, p, getattr(best_model[1], p))
                else:
                    m_init = best_model[1]

                # Reduce the search space if we've previously found a match
                # TODO: This code is more generic than it needs to be now (the model
                # only has unfixed offsets) but less generic than it will need to be
                # if a rotation or magnification is added)
                if best_model[1] is not None:
                    initial = 2.5 / pixscale
                    for param in [getattr(m_init, p) for p in m_init.param_names]:
                        if 'offset' in param.name and not param.fixed:
                            param.bounds = (param.value - initial,
                                            param.value + initial)

                # First: estimate number of reference sources in field
                # Inverse map ref coords->image plane and see how many are in field
                xx, yy = m_init(xref, yref) if m_init else (xref, yref)
                x1, y1 = 0, 0
                y2, x2 = ad[index].data.shape
                # Could tweak y1, y2 here for GNIRS
                in_field = np.all((xx > x1 - initial, xx < x2 + initial,
                                   yy > y1 - initial, yy < y2 + initial), axis=0)
                num_ref_sources = np.sum(in_field)

                # We probably don't need zillions of REFCAT sources
                if max_ref_sources and num_ref_sources > max_ref_sources:
                    ref_mags = None
                    try:
                        ref_mags = refcat['filtermag']
                        if np.all(np.where(np.isnan(ref_mags), -999,
                                           ref_mags) < -99):
                            log.stdinfo('The REFCAT magnitude column has no '
                                        'valid values')
                            ref_mags = None
                    except KeyError:
                        log.stdinfo('Cannot find a magnitude column to cull REFCAT')
                    if ref_mags is None:
                        for filt in 'rhikgjzu':
                            try:
                                ref_mags = refcat[filt+'mag']
                            except KeyError:
                                pass
                            else:
                                if not np.all(np.where(np.isnan(ref_mags), -999,
                                                       ref_mags) < -99):
                                    log.stdinfo('Using {} magnitude instead'.
                                                format(filt))
                                    break
                        else:
                            # No band had valid values: don't cull with an
                            # all-invalid column (that would drop every source).
                            ref_mags = None

                    if ref_mags is not None:
                        in_field &= (ref_mags > -99)
                        num_ref_sources = np.sum(in_field)
                        if num_ref_sources > max_ref_sources:
                            sorted_args = np.argsort(ref_mags)
                            in_field = sorted_args[in_field[sorted_args]][:max_ref_sources]
                            log.stdinfo('Using only {} brightest REFCAT sources '
                                        'for speed'.format(max_ref_sources))
                            # in_field is now a list of indices, not a boolean array
                            num_ref_sources = len(in_field)

                # How many objects do we want to try to match? Keep brightest ones only
                if objcat_len > 2 * num_ref_sources:
                    keep_num = max(2 * num_ref_sources, min(10, objcat_len))
                else:
                    keep_num = objcat_len
                sorted_idx = np.argsort(objcat['MAG_AUTO'])[:keep_num]

                # Send all sources to the alignment/matching engine, indicating the ones to
                # use for the alignment
                if num_ref_sources > 0:
                    log.stdinfo('Aligning {}:{} with {} REFCAT and {} OBJCAT sources'.
                                format(ad.filename, extver, num_ref_sources, keep_num))
                    matched, m_final = match_catalogs(xref, yref, objcat['X_IMAGE'], objcat['Y_IMAGE'],
                                                      use_in=in_field, use_ref=sorted_idx,
                                                      model_guess=m_init, translation_range=initial,
                                                      tolerance=0.05, match_radius=final)
                else:
                    log.stdinfo('No REFCAT sources in field of extver {}'.format(extver))
                    continue

                num_matched = np.sum(matched >= 0)
                log.stdinfo("Matched {} objects in OBJCAT:{} against REFCAT".
                            format(num_matched, extver))
                # If this is a "better" match, save it
                # TODO? Some sort of averaging of models?
                if num_matched > max(best_model[0], 2):
                    best_model = (num_matched, m_final)

                if num_matched > 0:
                    # Update WCS in the header and OBJCAT (X_WORLD, Y_WORLD)
                    if full_wcs:
                        new_wcs = m_final.wcs
                    else:
                        kwargs = dict(zip(m_final.param_names, m_final.parameters))
                        new_wcs = Pix2Sky(wcs, **kwargs).wcs
                    _write_wcs_keywords(ext, new_wcs, self.keyword_comments)
                    objcat['X_WORLD'], objcat['Y_WORLD'] = new_wcs.all_pix2world(
                        objcat['X_IMAGE'], objcat['Y_IMAGE'], 1)

                    # Sky coordinates of original CRPIX location with old
                    # and new WCS (easier than using the transform)
                    ra0, dec0 = wcs.all_pix2world([wcs.wcs.crpix], 1)[0]
                    ra1, dec1 = new_wcs.all_pix2world([wcs.wcs.crpix], 1)[0]
                    cosdec = math.cos(math.radians(dec0))
                    delta_ra = 3600 * (ra1-ra0) * cosdec
                    delta_dec = 3600 * (dec1-dec0)
                    all_delta_ra.append(delta_ra)
                    all_delta_dec.append(delta_dec)

                    # Associate REFCAT properties with their OBJCAT
                    # counterparts. Remember! matched is the reference
                    # (OBJCAT) source for the input (REFCAT) source
                    dra = []
                    ddec = []
                    for i, m in enumerate(matched):
                        if m >= 0:
                            objcat['REF_NUMBER'][m] = refcat['Id'][i]
                            try:
                                objcat['REF_MAG'][m] = refcat['filtermag'][i]
                                objcat['REF_MAG_ERR'][m] = refcat['filtermag_err'][i]
                            except KeyError:  # no such columns in REFCAT
                                pass
                            # Residuals in arcsec (3600 arcsec per degree).
                            dra.append(3600*(objcat['X_WORLD'][m] -
                                             refcat['RAJ2000'][i]) * cosdec)
                            ddec.append(3600*(objcat['Y_WORLD'][m] -
                                              refcat['DEJ2000'][i]))
                    dra_std = np.std(dra)
                    ddec_std = np.std(ddec)
                    log.fullinfo("WCS Updated for extver {}. Astrometric "
                                 "offset is:".format(extver))
                    log.fullinfo("RA:  {:.2f} +/- {:.2f} arcsec".
                                 format(delta_ra, dra_std))
                    log.fullinfo("Dec: {:.2f} +/- {:.2f} arcsec".
                                 format(delta_dec, ddec_std))
                    info_list.append({"dra": delta_ra, "dra_std": dra_std,
                                      "ddec": delta_dec, "ddec_std": ddec_std,
                                      "nsamples": int(num_matched)})
                else:
                    log.stdinfo("Could not determine astrometric offset for "
                                "{}:{}".format(ad.filename, extver))
                    info_list.append({})

            # Report the measurement to the fitsstore
            if self.upload and "metrics" in self.upload:
                qap.fitsstore_report(ad, "pe", info_list,
                                     self.calurl_dict,
                                     self.mode, upload=True)

            # Timestamp and update filename
            gt.mark_history(ad, primname=self.myself(), keyword=timestamp_key)
            ad.update_filename(suffix=params["suffix"], strip=True)

        return adinputs
Ejemplo n.º 17
0
class IPHASdataClass:
	def __init__(self):
		print "Initialising an empty IPHAS data class"
		self.originalImageData = None
		self.boostedImage = None
		self.FITSHeaders = {}
		self.filter = None
		self.pixelScale = None
		self.centre = None
		self.filename = None
		self.rootname = "unknown"
		self.ignorecache = False
		self.catalogs = {}
		self.figSize = 8.
		self.previewSize = 4.
		self.magLimit = 18
		self.mask = None
		self.borderSize = 50
		self.superPixelSize = 50
		self.spacingLimit = 60./60.  # Minimum spacing of pointings in arcminutes
		self.rejectTooManyMaskedPixels = 0.70
		self.varianceThreshold = 5
		self.fullDebug = False
		self.objectStore = {}
		self.activeColour = 'r'
		return None
		
	def setProperty(self, property, value):
		truths = ["true", "yes", "on", "1", "Y", "y", "True"]
		falses = ["false", "no", "off", "0", "N", "n", "False"]
		if property=='maglimit':
			self.__dict__['magLimit'] = float(value)
		if property=="ignorecache":
			if value in truths:
				self.ignorecache = True
			if value in falses:
				self.ignorecache = False
		if property=='superpixelsize':
			self.__dict__['superPixelSize'] = int(value)
		if property=='spacinglimit':
			self.__dict__['spacingLimit'] = float(value)
		if property=='plotwindowsize':
			self.__dict__['figSize'] = float(value)
		if property=='debug':
			if value in truths:
				self.fullDebug = True
			if value in falses:
				self.fullDebug = False
		if property=='colour' or property=='color':
			self.__dict__['activeColour'] = str(value)

			
	def getStoredObject(self, name):
		try:
			return self.objectStore[name]
		except KeyError:
			print "Could not find an object called %s in internal object storage."%name
		return None	
		
	def loadFITSFile(self, filename):
		hdulist = fits.open(filename)
		self.filename = filename
		self.rootname = filename.split(".")[0]
		FITSHeaders = []
		for card in hdulist:
			# print(card.header.keys())
			# print(repr(card.header))
			for key in card.header.keys():
				self.FITSHeaders[key] = card.header[key]
				if 'WFFBAND' in key:
					self.filter = card.header[key]
		import astropy.io.fits as pf
		self.originalImageData = pf.getdata(filename, uint=False, do_not_scale_image_data=False)
		# self.originalImageData =  hdulist[1].data
		self.height, self.width = numpy.shape(self.originalImageData)
		self.wcsSolution = WCS(hdulist[1].header)
		print "width, height", self.width, self.height, "shape:", numpy.shape(self.originalImageData)
		self.getRADECmargins()
		imageCentre = (self.width/2, self.height/2)
		ra, dec = self.wcsSolution.all_pix2world([imageCentre], 1)[0]
		self.centre = (ra, dec)
		positionString = generalUtils.toSexagesimal((ra, dec))
		print "RA, DEC of image centre is: ", positionString, ra, dec
		
		hdulist.close()
		
	def showVizierCatalogs(self):
		(ra, dec) = self.centre
		from astroquery.vizier import Vizier
		Vizier.ROW_LIMIT = 50
		from astropy import coordinates
		from astropy import units as u
		c = coordinates.SkyCoord(ra,dec,unit=('deg','deg'),frame='icrs')
		skyHeight= coordinates.Angle(self.raRange, unit = u.deg)
		results = Vizier.query_region(coordinates = c, radius= 1.0 * u.deg)
		print results
		
		
	def getVizierObjects(self, catalogName):
		""" Make a request to Vizier to get an Astropy Table of catalog object for this field. """
		(ra, dec) = self.centre
		
		availableCatalogs = catalogMetadata.keys()
		if catalogName not in availableCatalogs:
			print "The definitions for this catalogue are unknown. Available catalogues are:", availableCatalogs
			return
		
		# First look for a cached copy of this data
		filenameParts = self.filename.split('.')
		catalogCache = filenameParts[0] + "_" + catalogName + "_cache.fits"
		cached = False
		if not self.ignorecache:
			print "Looking for a cached copy of the catalogue:", catalogCache, 
			if os.path.exists(catalogCache):
				print "FOUND"
				cached = True
			else: print "NOT FOUND"
	
		if cached:
			newCatalog = Table.read(catalogCache)
		else:			
			print "Going online to fetch %s results from Vizier with mag limit %f."%(catalogName, self.magLimit)
			from astroquery.vizier import Vizier
			Vizier.ROW_LIMIT = 1E5
			Vizier.column_filters={"r":"<%d"%self.magLimit}
			from astropy import coordinates
			from astropy import units as u
			c = coordinates.SkyCoord(ra,dec,unit=('deg','deg'),frame='icrs')
			skyRA  = coordinates.Angle(self.raRange, unit = u.deg)
			skyDEC = coordinates.Angle(self.decRange, unit = u.deg)
			print "Sky RA, DEC range:", skyRA, skyDEC
			print "going to Astroquery for:", catalogMetadata[catalogName]['VizierLookup']
			result = Vizier.query_region(coordinates = c, width = skyRA, height = skyDEC, catalog = catalogMetadata[catalogName]['VizierName'], verbose=True)
			print result
			newCatalog = result[catalogMetadata[catalogName]['VizierName']]
			newCatalog.pprint()
			
			
			# Write the new catalog to the cache file
			newCatalog.write(catalogCache, format='fits', overwrite=True)
		
		self.addCatalog(newCatalog, catalogName)
		
		return
		
		
	def printCatalog(self, catalogName):
		catalog = self.catalogs[catalogName]
		for b in catalog:
			print b
		print "%d rows printed."%len(catalog)
		
	def typeObject(self, objectName):
		try:
			objects = self.objectStore[objectName]
			for index, o in enumerate(objects):
				print index, ":", o
		except KeyError:
			print "Could not find an object called %s stored internally."%objectName
	
	def addCatalog(self, catTable, catalogName):
		newCatalog = []
		columnMapper = catalogMetadata[catalogName]['columns']
		for index, row in enumerate(catTable):
			object={}
			skipRow = False
			for key in columnMapper.keys():
				object[key] = row[columnMapper[key]]
				if numpy.isnan(row[columnMapper[key]]): skipRow = True
			if skipRow: continue		
			x, y = self.wcsSolution.all_world2pix([object['ra']], [object['dec']], 1)
			object['x'] = x[0]
			object['y'] = y[0]
			
			newCatalog.append(object)
			if  ((index+1)%100) == 0:
				sys.stdout.write("\rCopying: %d of %d."%(index+1, len(catTable)))
				sys.stdout.flush()
		sys.stdout.write("\rCopying: %d of %d.\n"%(index+1, len(catTable)))
		sys.stdout.flush()
					
		trimmedCatalog = []
		for row in newCatalog:
			if row['x']<0: continue
			if row['x']>self.width: continue
			if row['y']<0: continue
			if row['y']>self.height: continue
			trimmedCatalog.append(row)
		print "Rejected %d points for being outside of the CCD x, y pixel boundaries."%(len(newCatalog)-len(trimmedCatalog))
		newCatalog = trimmedCatalog

		print "Adding catalog %s to list of stored catalogs."%catalogName
		self.catalogs[catalogName] =  newCatalog
		
		return
				
	def getRADECmargins(self):
		"""Derive the sky footprint and pixel scale of the loaded image from its WCS.

		Sets self.boundingBox (world coordinates of four corner positions, RA
		first), self.raRange / self.decRange (max - min over those corners, in
		the WCS world units, presumably degrees) and self.pixelScale (printed as
		arcsec/pixel). Requires self.wcsSolution, self.width and self.height.
		"""
		# NOTE(review): all_pix2world takes (x, y) = (column, row) pairs. These
		# corners pair 0/self.height with x and 0/self.width with y, so they are
		# transposed for a non-square image; the commented-out line below uses
		# the (x, y) order. drawBitmap() swaps the converted corners back when
		# drawing the box, so any fix must change both methods together.
		boundingBox = self.wcsSolution.all_pix2world([[0, 0], [0, self.width], [self.height, self.width], [self.height, 0]], 1, ra_dec_order = True)
		# boundingBox = self.wcsSolution.all_pix2world([[0, 0], [0, self.height], [self.width, self.height], [self.width, 0]], 1, ra_dec_order = True)
		print "Bounding box:", boundingBox
		# Diagonal from corner 0 to corner 2; its pixel length is the same
		# whichever axis order the corners use.
		pixelDiagonal = math.sqrt(self.height**2 + self.width**2)
		pixel1 = boundingBox[0]
		pixel2 = boundingBox[2]
		# assumes distance() returns the angular separation in degrees, given the
		# * 3600 conversion to arcsec below — TODO confirm
		skyDiagonal = distance(pixel1, pixel2)
		print "Diagonal size:", pixelDiagonal, skyDiagonal
		self.pixelScale = (skyDiagonal / pixelDiagonal) * 3600.
		# NOTE(review): plain min/max of RA does not handle fields that straddle
		# RA = 0/360; raRange would then come out close to 360.
		raMin = numpy.min([r[0] for r in boundingBox])
		raMax = numpy.max([r[0] for r in boundingBox])
		decMin = numpy.min([r[1] for r in boundingBox])
		decMax = numpy.max([r[1] for r in boundingBox])
		print "RA, DEC min/max:", raMin, raMax, decMin, decMax
		raRange = raMax - raMin
		decRange = decMax - decMin
		print "RA range, DEC range", raRange, decRange, raRange*60, decRange*60
		self.raRange = raRange
		self.decRange = decRange
		print "Pixel scale: %6.4f \"/pixel"%self.pixelScale
		self.boundingBox = boundingBox
		
	def showFITSHeaders(self):
		headersString = ""
		for key in self.FITSHeaders.keys():
			print key + " : " + str(self.FITSHeaders[key])
			headersString+= str(key) + " : " + str(self.FITSHeaders[key]) + "\n"
		return headersString
			
	def getFITSHeader(self, key):
		try:
			print key + " : " + str(self.FITSHeaders[key]) 
			return self.FITSHeaders[key]
		except KeyError:
			print "Could not find a header with the name:", key
			return None 
			
	def plotCatalog(self, catalogName):
		try:
			catalog = self.catalogs[catalogName]
		except KeyError:
			print "Could not find a catalog called %s."%catalogName
			return
		
		catalogColour = catalogMetadata[catalogName]['colour']
		try:
			xArray = []
			yArray = []
			rArray = []
			for o in catalog:
				# Check that the catalog has a class flag
				if 'class' in o.keys():
					if o['class'] != -1: continue   # Skip objects that are not stars  
				xArray.append(o['x'] - 1)
				yArray.append(self.height - 1 - o['y'] )
				if catalogName=='dr2':
					r = o['pixelFWHM']*8.
				else:
					if o['mag']>12:
						r = 40*math.exp((-o['mag']+12)/4)
					else: 
						r = 40
				rArray.append(r)
	
			# Nick Wright 
			# R / pixels = 8192/M^2 + 1000/M + 100 
			matplotlib.pyplot.figure(self.figure.number)
			patches = [matplotlib.pyplot.Circle((x_, y_), s_, fill=False, linewidth=1) for x_, y_, s_ in numpy.broadcast(xArray, yArray, rArray)]
			collection = matplotlib.collections.PatchCollection(patches, alpha = 0.25, color = catalogColour)
			ax = matplotlib.pyplot.gca()
			ax.add_collection(collection)
			matplotlib.pyplot.draw()
			matplotlib.pyplot.show()
			matplotlib.pyplot.pause(0.01)
			# matplotlib.pyplot.savefig("test.png", bbox_inches='tight')
		except AttributeError as e:
			print "There is no drawing surface defined yet. Please use the 'draw' command first."
			print e
		except Exception as e:
			print e
			
			
	def drawMask(self):
		if self.mask is None:
			print "There is no mask defined yet."
			return
		self.maskFigure = matplotlib.pyplot.figure(self.filename + " mask", figsize=(self.figSize/1.618, self.figSize))
		self.maskFigure.frameon = False
		self.maskFigure.set_tight_layout(True)
		axes = matplotlib.pyplot.gca()
		axes.set_axis_off()
		self.maskFigure.add_axes(axes)
		imgplot = matplotlib.pyplot.imshow(numpy.flipud(self.mask), cmap="gray_r", interpolation='nearest')
		matplotlib.pyplot.draw()
		matplotlib.pyplot.show()
		matplotlib.pyplot.pause(0.01)
		
		return
		
			
	def maskCatalog(self, catalogName):
		if self.mask is None:
			self.mask = numpy.zeros(numpy.shape(self.originalImageData))
			print "Creating a new blank mask of size:", numpy.shape(self.mask)

		# Mask the border areas
		if catalogName == 'border':
			border = self.borderSize
			self.mask[0:border, 0:self.width] = 132
			self.mask[self.height-border:self.height, 0:self.width] = 132
			self.mask[0:self.height, 0:border] = 132
			self.mask[0:self.height, self.width-border:self.width] = 132
			self.drawMask()
			return
			
		# Retrieve the catalogue
		try:
			catalog = self.catalogs[catalogName]
		except KeyError:
			print "Could not find a catalog called %s."%catalogName
			return
		

		xArray = []
		yArray = []
		rArray = []
		for o in catalog:
			# Check that the catalog has a class flag
			if 'class' in o.keys():
				if o['class'] != -1: continue   # Skip objects that are not stars  
			xArray.append(o['x'])
			yArray.append(o['y'])
			if catalogName=='dr2':
				r = o['pixelFWHM']*9.
			else:
				if o['mag']>12:
					r = 40*math.exp((-o['mag']+12)/4)
				else: 
					r = 50
			rArray.append(r)
			
		index = 1	
		for x, y, r in zip(xArray, yArray, rArray):
			self.mask = generalUtils.gridCircle(y, x, r, self.mask)
			sys.stdout.write("\rMasking: %d of %d."%(index, len(catalog)))
			sys.stdout.flush()
			index+= 1
		sys.stdout.write("\n")
		sys.stdout.flush()
	
		self.drawMask()
		
	def plotObject(self, objectName):
		objects = self.getStoredObject(objectName)
		
		colour = self.activeColour
		
		# Get the main plotting figure
		matplotlib.pyplot.figure(self.figure.number)
		
		for index, o in enumerate(objects):
			position = o.getPixelPosition()
			print position
			# matplotlib.pyplot.plot(o.x, o.y, color = 'r', marker='o', markersize=25, lw=4, fillstyle='none')
			xoffset = o.maxPosition[1]
			yoffset = self.superPixelSize - 2 - o.maxPosition[0]
			print "offsets", xoffset, yoffset
			matplotlib.pyplot.plot(o.x1 + xoffset, o.y1 + yoffset , color = colour, marker='o', markersize=15, mew=3, fillstyle='none')
			matplotlib.pyplot.annotate(str(index), (o.x1+xoffset+20, o.y1+yoffset), color=colour, fontweight='bold', fontsize=15)
			# if index==2: break
			
		matplotlib.pyplot.draw()
		matplotlib.pyplot.show()
		matplotlib.pyplot.pause(0.01)
		return
		
		
	def drawPreview(self, pointingsName, index, title=None):
		if title is None:
			title = "Preview of pointing number %d in %s"%(index, pointingsName)
		print "Creating preview: %s"%title
		
		
		objectList = self.getStoredObject(pointingsName)
		if objectList is None: return
		
		pointingObject = objectList[index]
		
		print "mean: %f"%pointingObject.mean
		self.previewFigure = matplotlib.pyplot.figure(title, figsize=(self.previewSize, self.previewSize))
		self.previewFigure.frameon = False
		self.previewFigure.set_tight_layout(True)
		axes = matplotlib.pyplot.gca()
		axes.cla()
		axes.set_axis_off()
		self.previewFigure.add_axes(axes)
		imgplot = matplotlib.pyplot.imshow(numpy.flipud(pointingObject.data), cmap="hsv", interpolation='nearest')
		matplotlib.pyplot.plot(pointingObject.maxPosition[1], self.superPixelSize - 2 - pointingObject.maxPosition[0], color = 'r', marker='o', markersize=25, lw=4, fillstyle='none')
		matplotlib.pyplot.plot(10, 10, color = 'g', marker='x')
		matplotlib.pyplot.draw()
		matplotlib.pyplot.show()
		matplotlib.pyplot.pause(0.01)
		
		print pointingObject.data
		print pointingObject.ra, pointingObject.dec, generalUtils.toSexagesimal((pointingObject.ra, pointingObject.dec))
		return
		
			
	def drawBitmap(self):
		"""Display the contrast-boosted image in the main figure and outline its bounding box.

		On first use the image is boosted with
		generalUtils.percentiles(copy of image, 20, 99) and cached in
		self.boostedImage. Creates self.figure, which the other plotting
		methods draw into. Requires self.boundingBox (see getRADECmargins()).
		"""
		if self.boostedImage is None:
			print "Boosting the image"
			self.boostedImage = generalUtils.percentiles(numpy.copy(self.originalImageData), 20, 99)
		matplotlib.pyplot.ion()
		# mplFrame = numpy.rot90(self.boostedImage)
		mplFrame = self.boostedImage
		# Flipped vertically; presumably so image row 0 appears at the bottom,
		# since imshow draws row 0 at the top by default.
		mplFrame = numpy.flipud(mplFrame)
		self.figure = matplotlib.pyplot.figure(self.filename, figsize=(self.figSize/1.618, self.figSize))
		self.figure.frameon = False
		self.figure.set_tight_layout(True)
		axes = matplotlib.pyplot.gca()
		axes.set_axis_off()
		self.figure.add_axes(axes)
		imgplot = matplotlib.pyplot.imshow(mplFrame, cmap="gray_r", interpolation='nearest')
		
		verts = []
		for b in self.boundingBox:
			print b
			# all_world2pix returns (x, y); unpacking into (y, x) swaps them,
			# which undoes the transposed corner order used in getRADECmargins().
			y, x = self.wcsSolution.all_world2pix(b[0], b[1], 1, ra_dec_order=True)
			coord  = (float(x), float(y))
			print coord
			verts.append(coord)
			
		# Placeholder vertex for CLOSEPOLY; matplotlib ignores its coordinates.
		verts.append((0, 0))
			
		print verts
		codes = [Path.MOVETO,
		         Path.LINETO,
		         Path.LINETO,
		         Path.LINETO,
		         Path.CLOSEPOLY,
		         ]

		path = Path(verts, codes)

		patch = matplotlib.patches.PathPatch(path, fill=None, lw=2)
		axes.add_patch(patch)
		matplotlib.pyplot.draw()
		matplotlib.pyplot.show(block=False)
		matplotlib.pyplot.draw()
		matplotlib.pyplot.pause(0.01)
		
		# matplotlib.pyplot.savefig("test.png",bbox_inches='tight')
		
	def applyMask(self):
		if self.mask is None:
			print "There is no mask defined. Define one with the 'mask' command."
			return
		
		if self.originalImageData is None:
			print "There is no source bitmap defined. Load one with the 'load' command."
			return
			
			
		booleanMask = numpy.ma.make_mask(numpy.flipud(self.mask))
		maskedImageData = numpy.ma.masked_array(self.originalImageData,  numpy.logical_not(booleanMask))
		
		self.maskedImage = maskedImageData
		
		matplotlib.pyplot.figure(self.figure.number)
		axes = matplotlib.pyplot.gca()
		imgplot = matplotlib.pyplot.imshow(self.maskedImage, cmap="gray_r", interpolation='nearest')
		matplotlib.pyplot.draw()
		matplotlib.pyplot.show()
		matplotlib.pyplot.pause(0.01)

	def makeSuperPixels(self):
		superPixelList = []
		superPixelSize = self.superPixelSize
		borderMask = self.borderSize	
		width = self.width
		height = self.height
		
		# Draw the grid on the matplotlib panel
		matplotlib.pyplot.figure(self.figure.number)
		# axes = matplotlib.pyplot.gca()
		for yStep in range(borderMask, self.height-borderMask, superPixelSize):
			matplotlib.pyplot.plot([borderMask, self.width - borderMask], [yStep, yStep], ls=':', color='g', lw=2)
		for xStep in range(borderMask, self.width-borderMask, superPixelSize):
			matplotlib.pyplot.plot([xStep, xStep], [borderMask, self.height - borderMask], ls=':', color='r', lw=2)
		matplotlib.pyplot.draw()
		matplotlib.pyplot.show()
		matplotlib.pyplot.pause(0.01)
		# End of drawing
		
		imageCopy = numpy.copy(self.originalImageData)
		booleanMask = numpy.ma.make_mask(self.mask)
		maskedImageCopy = numpy.ma.masked_array(imageCopy, booleanMask)
			
		pixelBitmapWidth = int((width - 2.*borderMask) / superPixelSize) + 1
		pixelBitmapHeight = int((height - 2.*borderMask) / superPixelSize) + 1
		pixelBitmap = numpy.zeros((pixelBitmapHeight, pixelBitmapWidth))
		pixelBitmap.fill(99E9) 
		
		rejectMaskCount = 0
		rejectVarCount = 0
		index = 0
		for yStep in range(borderMask, self.height-borderMask, superPixelSize):
			# matplotlib.pyplot.plot([borderMask, self.width - borderMask], [yStep, yStep], ls=':', color='g')
			for xStep in range(borderMask, self.width-borderMask, superPixelSize):
				"""index+=1
				if index>30: return
				"""
				x1 = xStep
				x2 = xStep + superPixelSize - 1
				y1 = yStep
				y2 = yStep + superPixelSize - 1
				# print xStep, yStep, x1, x2, y1, y2, self.height-y2, self.height-y1
				superPixel = maskedImageCopy[self.height-y2:self.height-y1, x1:x2]
				superPixelObject = {}
				mean = float(numpy.ma.mean(superPixel))
				if math.isnan(mean): continue;
				superPixelObject['mean'] = mean
				superPixelObject['median'] = numpy.ma.median(superPixel)
				superPixelObject['min'] = numpy.ma.min(superPixel)
				superPixelObject['max'] = numpy.ma.max(superPixel)
				superPixelObject['x1'] = x1
				superPixelObject['y1'] = y1
				superPixelObject['x2'] = x2
				superPixelObject['y2'] = y2
				superPixelObject['xc'] = x1 + superPixelSize/2.
				superPixelObject['yc'] = y1 + superPixelSize/2.
				superPixelObject['data'] = superPixel
				
				
				bitmapX = (x1-borderMask)/superPixelSize
				bitmapY = (y1-borderMask)/superPixelSize
				
				if self.fullDebug:
					matplotlib.pyplot.figure(self.figure.number)
					matplotlib.pyplot.plot(superPixelObject['xc'], superPixelObject['yc'], color = 'r', marker='x', markersize=25, lw=4, fillstyle='none')
					matplotlib.pyplot.draw()
					matplotlib.pyplot.show()
					matplotlib.pyplot.pause(0.001)
				
					self.previewFigure = matplotlib.pyplot.figure("Superpixel", figsize=(self.previewSize, self.previewSize))
					self.previewFigure.frameon = False
					self.previewFigure.set_tight_layout(True)
					axes = matplotlib.pyplot.gca()
					axes.cla()
					axes.set_axis_off()
					self.previewFigure.add_axes(axes)
					imgplot = matplotlib.pyplot.imshow(numpy.flipud(superPixelObject['data']), cmap="gray_r", interpolation='nearest')
					matplotlib.pyplot.draw()
					matplotlib.pyplot.show()
					matplotlib.pyplot.pause(0.001)
					print x1, x2, y1, y2, superPixelObject['mean'], superPixelObject['median']
					raw_input("Press Enter to continue...")
					

				variance = numpy.ma.var(superPixel)
				numPixels= numpy.ma.count(superPixel)
				superPixelObject['varppixel'] = variance/numPixels
				if superPixelObject['varppixel']>self.varianceThreshold: 
					rejectVarCount+= 1
					continue
				
				numMaskedPixels = numpy.ma.count_masked(superPixel)
				superPixelObject['maskedpixels'] = numMaskedPixels
				maskedRatio = float(numMaskedPixels)/float(numPixels)
				if maskedRatio>self.rejectTooManyMaskedPixels: 
					# print "too many masked pixels here. Rejecting."
					rejectMaskCount+=1
					continue;
				superPixelList.append(superPixelObject)
				pixelBitmap[bitmapY, bitmapX] = mean
				
		print "%d pixels rejected for having too many masked pixels. Masked pixel ratio > %2.2f%%"%(rejectMaskCount, self.rejectTooManyMaskedPixels)
		print "%d pixels rejected for having too large variance. Variance per pixel > %2.2f"%(rejectVarCount, self.varianceThreshold)
		
		self.sampledImageFigure = matplotlib.pyplot.figure("Sampled Image", figsize=(self.figSize/1.618, self.figSize))
		self.sampledImageFigure.frameon = False
		self.sampledImageFigure.set_tight_layout(True)
		axes = matplotlib.pyplot.gca()
		axes.cla()
		axes.set_axis_off()
		self.sampledImageFigure.add_axes(axes)
		
		maskedPixelImage = numpy.ma.masked_equal(pixelBitmap, 99E9)
		
		
		# minimumPixel = numpy.min(pixelBitmap)
		# pixelBitmap[pixelBitmap==99E9] = minimumPixel
		
		imgplot = matplotlib.pyplot.imshow(maskedPixelImage, cmap="hsv", interpolation='nearest')
		matplotlib.pyplot.draw()
		matplotlib.pyplot.show()
				
		
		self.superPixelList = superPixelList
		return 
		
	def getRankedPixels(self, number=50):
		# Top sources
		top = True
		if number<0:
			top = False
			number = abs(number)
			
		# Sort superpixels
		if top: self.superPixelList.sort(key=lambda x: x['mean'], reverse=True)
		else: self.superPixelList.sort(key=lambda x: x['mean'], reverse=False)
		
		
		
		pointings = []
		distanceLimitPixels = self.spacingLimit*60/self.pixelScale
		
		for index, s in enumerate(self.superPixelList):
			print index, s['mean'], s['varppixel'], s['xc'], s['yc']
			pointingObject = Pointing()
			pointingObject.length = self.superPixelSize
			pointingObject.x1 = s['x1']
			pointingObject.y1 = s['y1']
			pointingObject.x2 = s['x2']
			pointingObject.y2 = s['y2']
			
			pointingObject.x = s['xc']
			pointingObject.y = s['yc']
			pointingObject.mean = s['mean']
			pointingObject.varppixel = s['varppixel']
			pointingObject.data = s['data']
			if top: pointingObject.type = "Maximum"
			else: pointingObject.type = "Minimum"
			# Check if this is not near to an existing pointing
			reject = False
			for p in pointings:
				if distanceP(p, pointingObject) < distanceLimitPixels: 
					reject=True
					break
			if not reject: pointings.append(pointingObject)
			if len(pointings)>=number: break;
		
		# Compute the position of the max for each pointing and store it internally
		for p in pointings:
			p.computeMax()	
			p.computeAbsoluteLocation(self.wcsSolution)
		return pointings
		
	def clearFigure(self):
		""" Clears the main drawing window """
		print "Clearing the current figure."
		matplotlib.pyplot.figure(self.figure.number)
		matplotlib.pyplot.clf()
		return
		
	def dumpImage(self, filename):
		"""Save the main figure as a PNG file.

		'{root}' in *filename* is replaced by self.rootname. '.png' is appended
		unless the name already ends in exactly that extension.
		"""
		pyplot = matplotlib.pyplot
		pyplot.figure(self.figure.number)
		outputName = filename.format(root = self.rootname)
		if os.path.splitext(outputName)[1] != ".png":
			outputName += ".png"
		pyplot.savefig(outputName, bbox_inches='tight')
		
	
	def dumpObject(self, objectName, filename, outputFormat):
		print "About to dump %s"%objectName
		objects = self.getStoredObject(objectName)
			
		filename = filename.format(root = self.rootname)
			
		if outputFormat=="json":
			objectList = [o.toJSON() for o in objects]
			outputFile = open(filename, "wt")
			outputFile.write(json.dumps(objectList))
			outputFile.close()
			return
		
		if outputFormat=="fits" or outputFormat=="votable":
			objectTable = Table()
			ids = []
			for index, o in enumerate(objects):
				id = self.rootname + "-%02d"%index
				if o.type=="Minimum": id = "sky-" + self.rootname + "-%02d"%index
				ids.append(id)
			objectTable['id'] = ids
			objectTable['ra'] = [o.ra for o in objects]
			objectTable['dec'] = [o.dec for o in objects]
			objectTable['xmax'] = [o.AbsoluteLocationPixels[0] for o in objects]
			objectTable['ymax'] = [o.AbsoluteLocationPixels[1] for o in objects]
			objectTable['mean'] = [o.mean for o in objects]
			objectTable['peak'] = [o.peak for o in objects]
			objectTable['variance'] = [o.varppixel for o in objects]
			objectTable['type'] = [o.type for o in objects]
			
			objectTable.write(filename, format='fits', overwrite=True)
			return
		
		
	def listPixels(self, number=0):
		for index, s in enumerate(self.superPixelList):
			print s['mean'], s['xc'], s['yc']
			print s
			if number!=0 and index==number:
				return
		return
		
		"""print "Original range:", numpy.min(self.originalImageData), numpy.max(self.originalImageData)
Ejemplo n.º 18
0
        n = overlap.sum()
        if n > 0:
            # New Will knots have velocity encoded in knot_id
            v0, dv = look_for_velocity(knot_id, line_id)
            if v0 is None:
                # But fall back on table look-up for the original Alba knots
                if not knot_id in tab["knot"]:
                    # If not there either, then skip this region
                    print("Warning: Knot", knot_id, "not found in table!")
                    continue
                # Extract row from table
                knotrow = tab[tab["knot"] == knot_id][0]
                v0, dv = knotrow[vcol[line_id]], knotrow[wcol[line_id]]
            if region_frame == "image":
                # Find j-pixel coordinates along the slit that correspond to this knot
                _, jslit, _ = specwcs.all_world2pix([0] * n, X[overlap], Y[overlap], 0)
                j1, j2 = jslit.min(), jslit.max()
                # Find i-pixel coordinates coresponding to knot velocity +/- width
                # Make sure to convert from km/s -> m/s since wcs is in SI
                v1, v2 = 1000 * (v0 - dv / 2), 1000 * (v0 + dv / 2)

                [i1, i2], _, _ = specwcs.all_world2pix([v1, v2], [0, 0], [0, 0], 0)
                i0, w = 0.5 * (i1 + i2), i2 - i1
                j0, h = 0.5 * (j1 + j2), j2 - j1
                pvregions.append([knot_id, i0, j0, w, h])
            elif region_frame == "linear":
                # Regions written in x = km/s and y = map X or Y,
                # depending on orientation
                S = X if orient == "horizontal" else Y
                s1, s2 = S[overlap].min(), S[overlap].max()
                s0, ds = 0.5 * (s1 + s2), (s2 - s1)
Ejemplo n.º 19
0
def get_cutout(ra, dec, angsize, survey, savepng=False, results=True, ids=[], filter=-1, path='./'):
    '''
    This function plots and retruns a cut-out
    :param ra: RA coordinate
    :param dec: DEC coordinate
    :param angsize: angular size of the image in arcsec
    :param survey: survey name
    :param savepng: True/False -- save a png image
    :param results: True/False -- return the images
    :param ids: ids of the sources (necessary to name an output file)
    :param filter: if -1 all filters are processed, if [...] -- only selected filters
    :param path: a path where data is stored
    :return: list of 2D images
    '''
    # for each survey, we define a list of files
    if survey=='AEGIS':
        prefix = path+'data/aegis/'
        files = [['IRAC1', 'AEGIS_CH1_SEDS_sci_reg.fits'],
                 ['IRAC2', 'AEGIS_CH2_SEDS_sci_reg.fits'],
                 ['F125W', 'aegis_3dhst.v4.0.F125W_orig_sci.fits.gz'],
                 ['F140W', 'aegis_3dhst.v4.0.F140W_orig_sci.fits.gz'],
                 ['F160W', 'aegis_3dhst.v4.0.F160W_orig_sci.fits.gz'],
                 ['F606W', 'aegis_3dhst.v4.0.F606W_orig_sci.fits.gz'],
                 ['F814W', 'aegis_3dhst.v4.0.F814W_orig_sci.fits.gz']]
    elif survey == 'COSMOS':
        prefix = path+'data/cosmos/'
        files = [['IRAC1', 'COSMOS_CH1_SEDS_sci_reg.fits'],
                 ['IRAC2', 'COSMOS_CH2_SEDS_sci_reg.fits'],
                 ['F125W', 'cosmos_3dhst.v4.0.F125W_orig_sci.fits.gz'],
                 ['F140W', 'cosmos_3dhst.v4.0.F140W_orig_sci.fits.gz'],
                 ['F160W', 'cosmos_3dhst.v4.0.F160W_orig_sci.fits.gz'],
                 ['F606W', 'cosmos_3dhst.v4.0.F606W_orig_sci.fits.gz'],
                 ['F814W', 'cosmos_3dhst.v4.0.F814W_orig_sci.fits.gz']]
    elif survey == 'GOODS-N':
        prefix = path+'data/goodsn/'
        files = [['IRAC1', 'GOODS-N_SEDS1_sci_sub_reg.fits'],
                 ['IRAC2', 'GOODS-N_SEDS2_sci_sub_reg.fits'],
                 ['F125W', 'goodsn_3dhst.v4.0.F125W_orig_sci.fits.gz'],
                 ['F140W', 'goodsn_3dhst.v4.0.F140W_orig_sci.fits.gz'],
                 ['F160W', 'goodsn_3dhst.v4.0.F160W_orig_sci.fits.gz'],
                 ['F435W', 'goodsn_3dhst.v4.0.F435W_orig_sci.fits.gz'],
                 ['F606W', 'goodsn_3dhst.v4.0.F606W_orig_sci.fits.gz'],
                 ['F775W', 'goodsn_3dhst.v4.0.F775W_orig_sci.fits.gz'],
                 ['F850LP', 'goodsn_3dhst.v4.0.F850LP_orig_sci.fits.gz']]
    elif survey=='GOODS-S':
        prefix = path+'data/goodss/'
        files = [['IRAC1', 'GOODS-S_SEDS1_sci_sub_reg.fits'],
                 ['IRAC2', 'GOODS-S_SEDS2_sci_sub_reg.fits'],
                 ['F125W', 'goodss_3dhst.v4.0.F125W_orig_sci.fits.gz'],
                 ['F140W', 'goodss_3dhst.v4.0.F140W_orig_sci.fits.gz'],
                 ['F160W', 'goodss_3dhst.v4.0.F160W_orig_sci.fits.gz'],
                 ['F435W', 'goodss_3dhst.v4.0.F435W_orig_sci.fits.gz'],
                 ['F606W', 'goodss_3dhst.v4.0.F606W_orig_sci.fits.gz'],
                 ['F775W', 'goodss_3dhst.v4.0.F775W_orig_sci.fits.gz'],
                 ['F850LP', 'goodss_3dhst.v4.0.F850LP_orig_sci.fits.gz']]
    elif survey == 'UDS':
        prefix = path+'data/uds/'
        files = [['IRAC1', 'UDS_SEDS1_sci_sub_reg.fits'],
                 ['IRAC2', 'UDS_SEDS2_sci_sub_reg.fits'],
                 ['F125W', 'uds_3dhst.v4.0.F125W_orig_sci.fits.gz'],
                 ['F140W', 'uds_3dhst.v4.0.F140W_orig_sci.fits.gz'],
                 ['F160W', 'uds_3dhst.v4.0.F160W_orig_sci.fits.gz'],
                 ['F606W', 'uds_3dhst.v4.0.F606W_orig_sci.fits.gz'],
                 ['F814W', 'uds_3dhst.v4.0.F814W_orig_sci.fits.gz']]
    else:
        print('ERROR!')
    res = []
    # if the picture is requested, initialize a figure
    if savepng:
        fig = plt.figure(0, figsize=(1, 1), dpi=200, frameon=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
    # loop through all files
    if filter==-1:
        filters = range(len(files))
    else:
        filters = filter
    for i in filters:
        print(files[i][0])
        w = WCS(prefix + files[i][1])
        data = pf.open(prefix + files[i][1])
        sizex = np.abs(1.0 * angsize / 3600 / data[0].header['CD2_2'])/2.0
        sizey = np.abs(1.0 * angsize / 3600 / data[0].header['CD1_1'])/2.0
        output = []
        # loop through all pointings
        for j in range(len(ra)):
            # converts the sky coordinates into pixel position
            temp = w.all_world2pix(ra[j], dec[j], 1, quiet=True)
            x = temp[1]  # [0]
            y = temp[0]  # [0]
            img = data[0].data[np.round(x)-sizex:np.round(x)+sizex, np.round(y)-sizey:np.round(y)+sizey]
            if results:
                output.append(img)
            # if the picture is requested, generate an image
            if savepng:
                vmax = img[sizex-5:sizex+5, sizey-5:sizey+5].max()
                plt.clf()
                ax = plt.Axes(fig, [0., 0., 1., 1.])
                ax.set_axis_off()
                fig.add_axes(ax)
                ax.imshow(img, interpolation='nearest', cmap='gray_r', aspect='normal', vmax=vmax)
                plt.savefig('output/'+survey+'/%06d_%2.2f_'%(ids[j], angsize)+files[i][0]+'.png')
        if results:
            res.append(output)
    return res
Ejemplo n.º 20
0
        wim = WCS('/Users/Cohn/Desktop/summer15research/masses with corrected foreground intensity ratios/filament_'+str(n)+'_mass_above_3_background_std_above_mean_mask.fits')
        
        lon_array=np.zeros((nrows_im,ncols_im))
        lat_array=np.zeros((nrows_im,ncols_im))
        for row in range(nrows_im):
            for col in range(ncols_im):
                lon, lat = wim.all_pix2world(col,row,0)
                lon_array[row,col]=lon
                lat_array[row,col]=lat

        x_array=np.zeros((nrows_im,ncols_im))
        y_array=np.zeros((nrows_im,ncols_im))
        for row in range(nrows_im):
            for col in range(ncols_im):
                if imdata[row,col]==1:
                    x, y = w.all_world2pix(lon_array[row,col],lat_array[row,col],0)
                    x_array[row,col] = x
                    y_array[row,col] = y

        data=lbdata*data
        new_map=np.zeros((nrows_dat,ncols_dat))
        total_col=0
        total_pix=0
        for row in range(nrows_im):
            for col in range(ncols_im):
                x=x_array[row,col]
                y=y_array[row,col]
                data[y,x]=np.nan_to_num(data[y,x])
                if data[y,x]>0:
                    new_map[y,x]=data[y,x]
                    total_col += data[y,x]
 def photometry(self, fits_file, radius):
     '''
     Perform aperture photometry on an int type .fits file.

     The matching sky-background image is located by replacing "d-int" with
     "d-skybg" in ``fits_file``. For every supernova in ``self.cord_dict``
     whose pixel position lies inside the hard-coded 0-3600 pixel window, a
     circular sky aperture of ``radius`` kpc (converted to an angle with
     ``cosmo.kpc_comoving_per_arcmin`` at the redshift in ``self.red_dict``)
     is measured on both images and the background sum is subtracted.

     Args:
         fits_file (str)  : File path of an int type .fits file
         radius    (float): Unitless radius of the desired aperture in kpc

     Returns:
         results (list): one entry per measured supernova,
                         [supernova name (str),
                          exposure time (EXPTIME header value),
                          photometry value (float),
                          photometry_error (float)]
                         Entries may instead be error records of the form
                         ["error" (str), fits file path (str),
                          error description (str)]; a single error record
                         is returned when the skybg file is missing or no
                         supernova could be measured.
     '''
     skybg_path = fits_file.replace("d-int", "d-skybg")
     if not os.path.isfile(skybg_path):
         return [["error", fits_file, "no skybg file"]]

     results = []
     with fits.open(fits_file) as int_file, fits.open(skybg_path) as skybg_file:
         wcs = WCS(fits_file)

         # The exposure time and error maps are identical for every supernova
         # in this file, so they are computed once, on first use.
         error_maps = None

         for sn in self.cord_dict:
             coord = self.cord_dict[sn]

             # Define the SN location in pixels (origin 1, FITS convention)
             pix = wcs.all_world2pix(coord.ra, coord.dec, 1)

             # Make sure the sn is located in the image.
             # NOTE(review): 3600 is a hard-coded extent, not read from the data.
             if not (0 < pix[0] < 3600 and 0 < pix[1] < 3600):
                 continue

             # Angular size of the requested physical radius at this redshift
             r = radius * u.kpc / cosmo.kpc_comoving_per_arcmin(float(self.red_dict[sn]))
             aperture = SkyCircularAperture(coord, r)

             if error_maps is None:
                 exp_time = int_file[0].header["EXPTIME"]
                 error_maps = (np.sqrt(int_file[0].data / exp_time),
                               np.sqrt(skybg_file[0].data / exp_time))
             int_error, skybg_error = error_maps

             int_phot_table = aperture_photometry(int_file[0], aperture, error=int_error)
             # Index columns by name: newer photutils versions prepend an 'id'
             # column, so positional access ([0][0]) would read the id instead.
             int_sum = int_phot_table['aperture_sum'][0]
             int_err = int_phot_table['aperture_sum_err'][0]

             if int_sum == 0:
                 # A zero aperture sum is handed to zero_check: only check == 1
                 # proceeds to the measurement; check == 0 is recorded as an
                 # error and any other value skips this supernova.
                 check = self.zero_check(fits_file, coord, r)
                 if check == 0:
                     results.append(["error", fits_file, "no check file"])
                     continue
                 if check != 1:
                     continue

             skybg_phot_table = aperture_photometry(skybg_file[0], aperture, error=skybg_error)
             photometry_sum = int_sum - skybg_phot_table['aperture_sum'][0]
             photometry_error = np.sqrt(int_err**2 + skybg_phot_table['aperture_sum_err'][0]**2)

             results.append([sn, exp_time, photometry_sum, photometry_error])

     if not results:
         results.append(["error", fits_file, "no supernova found"])
     return results

# The next steps should probably be done with the butler, but we want
# to allow users to work without the DM stack for now.

# Read the L2 output catalog
mytab = astropy.table.Table.read(o.input_catalog)
# Open the calibrated exposure to obtain the zeropoints.
# NOTE: hdulist is deliberately left open because `reference` (one of its
# HDUs) is used later on.
hdulist = fits.open(o.calexp)
# The HDU containing the image, which we will use later
reference = hdulist[1]
# WCS of the calibrated exposure, used to compute the chip positions of the input sources
w = WCS(reference.header)
#xcent, ycent = w.all_world2pix(w.wcs.crval[0],w.wcs.crval[1],0.,ra_dec_order=True)
#print xcent, ycent, w.wcs.crval[0], w.wcs.crval[1], dataframe['ra'][0], dataframe['dec'][0], ' Central pixels'
x, y = w.all_world2pix(dataframe['ra']+dataframe['delta_ra'],dataframe['dec']+dataframe['delta_dec'],0.,ra_dec_order=True)
# Zeropoint from the calibrated exposure (FLUXMAG0 header keyword)
zeropoint = 2.5 * np.log10(hdulist[0].header["FLUXMAG0"])
# PSF magnitudes of the L2 sources
psfmag_tot = zeropoint - 2.5 * np.log10(mytab['base_PsfFlux_flux'])
# Truth table used for the comparison
catsim = astropy.table.Table.read(o.mag_truth)

# Select stars and galaxies using the truth table SED names
star_sel = dataframe['sed_name'].str.contains('star')
gal_sel = dataframe['sed_name'].str.contains('galaxy')

# Some useful numbers.
# Single-argument print() produces the same output on Python 2 and 3
# (the old `print 'a ', n` statement is a SyntaxError on Python 3).
print('Number of stars  %d' % np.count_nonzero(star_sel.values))
print('Number of galaxies  %d' % np.count_nonzero(gal_sel.values))
Ejemplo n.º 23
0
def plot_composites(pdata,idx_plot,outfolder,contours,contour_colors=True,calibration_plot=True):
	"""Plot WISE colour contours on SDSS imaging, calibration panels and the SED.

	One PNG per index in ``idx_plot`` is written to ``outfolder``. Each figure
	has six image panels plus an SED axis:
	ax[0] raw contours[0] image, ax[1] contours[0] convolved to the
	contours[1] resolution, ax[2] raw contours[1] image, ax[3] colour map,
	ax[4] colour map with faint pixels masked, ax[5] SDSS image with the
	colour contours overlaid.

	Args:
		pdata: nested dict-like of per-object results (keys used here:
			'objname', 'pars', 'observables', 'lsfr', 'bpt').
		idx_plot: object indices into ``pdata`` to plot.
		outfolder: directory the PNG files are written to.
		contours: the two band names used for the colour contours
			(colour limits below are tuned for W1-W2).
		contour_colors, calibration_plot: accepted for compatibility;
			not used by this function.
	"""

	### image qualities
	maxlim = 0.05

	### filters
	#filters = ['SDSS u','SDSS g','SDSS i']
	#fcolors = ['Blues','Greens','Reds']
	#ftext = ['blue','green','red']
	filters = ['SDSS i']
	fcolors = ['Greys']
	ftext = ['black']

	### contour color limits (customized for W1-W2)
	color_limits = [-1.0,2.6]

	kernel = None

	### begin loop
	for ii,idx in enumerate(idx_plot):

		### load object information
		objname = pdata['objname'][idx]
		fagn = pdata['pars']['fagn']['q50'][idx]
		ra, dec = load_coordinates(objname)
		phot_size = load_structure(objname,long_axis=True) # in arcseconds

		### set up figure
		fig, ax = None, None
		xs, ys, dely = 0.05,0.9, 0.07

		for kk,filt in enumerate(filters):
			hdu = load_image(objname,filt)

			#### if it's the first filter,
			#### set up WCS using this information
			if fig is None:

				### grab WCS information, create figure + axis
				wcs = WCS(hdu.header)
				fig, ax = plt.subplots(2,3, figsize=(18, 18))
				plt.subplots_adjust(top=0.95,bottom=0.33)
				sedax = fig.add_axes([0.3,0.05,0.4,0.25])
				ax = np.ravel(ax)

				### translate object location into pixels using WCS coordinates
				pix_center = wcs.all_world2pix([[ra[0],dec[0]]],1)
				size = calc_dist(wcs, pix_center, phot_size, hdu.data.shape)
				hdu_original = copy.deepcopy(hdu.header)
				data_to_plot = hdu.data

				### build image extents
				# first calculate pixel location of image left, image bottom
				center_pix = np.atleast_2d([(size[0]+size[1])/2.,(size[2]+size[3])/2.])
				center_left_pix = np.atleast_2d([size[0],center_pix[0][1]])
				center_bottom_pix = np.atleast_2d([center_pix[0][0],size[2]])

				# now wcs location
				center_left_wcs = wcs.all_pix2world(center_left_pix,0)
				center_bottom_wcs = wcs.all_pix2world(center_bottom_pix,0)
				center_wcs = wcs.all_pix2world(center_pix,0)

				# now calculate distance
				center = SkyCoord(ra=center_wcs[0][0]*u.degree,dec=center_wcs[0][1]*u.degree)
				center_left = SkyCoord(ra=center_left_wcs[0][0]*u.degree,dec=center_left_wcs[0][1]*u.degree)
				center_bottom = SkyCoord(ra=center_bottom_wcs[0][0]*u.degree,dec=center_bottom_wcs[0][1]*u.degree)
				ydist = center.separation(center_bottom).arcsec
				xdist = center.separation(center_left).arcsec

				extent = [-xdist,xdist,-ydist,ydist]

			#### if it's not the first filter,
			#### project into WCS of first filter
			# see reprojection https://reproject.readthedocs.io/en/stable/
			else:
				data_to_plot, footprint = reproject_exact(hdu, hdu_original)

			plot_image(ax[5],data_to_plot[size[2]:size[3],size[0]:size[1]],size,cmap=fcolors[kk],extent=extent)
			ax[5].text(xs, ys, filters[kk]+'-band',color=ftext[kk],transform=ax[5].transAxes)
			ys -= dely

			### draw 6" line
			wise_psf = 6 # in arcseconds
			start = -0.85*xdist
			ax[5].plot([start,start+wise_psf],[start,start],lw=2,color='k')
			ax[5].text(start+wise_psf/2.,start+1, '6"', ha='center')
			ax[5].set_xlim(-xdist,xdist) # reset plot limits b/c of text stuff
			ax[5].set_ylim(-ydist,ydist)

		ax[5].set_xlabel('arcseconds')
		ax[5].set_ylabel('arcseconds')

		#### load up HDU, subtract background and convert to physical units
		# also convolve to W2 resolution
		hdu = load_image(objname,contours[0])
		hdu.data *= 1.9350E-06 ### convert from DN to flux in Janskies, from this table: http://wise2.ipac.caltech.edu/docs/release/allsky/expsup/sec2_3f.html
		hdu.data -= np.median(hdu.data) ### subtract background as median
		data1_noconv, footprint = reproject_exact(hdu, hdu_original)
		data_convolved, kernel = match_resolution(hdu.data,contours[0],contours[1],kernel=kernel,data1_res=hdu.header['PXSCAL1']) # convolve to W2 resolution
		hdu.data = data_convolved
		data1, footprint = reproject_exact(hdu, hdu_original)

		### load up HDU2, subtract background, convert to physical units
		hdu = load_image(objname,contours[1])
		hdu.data -= np.median(hdu.data) ### subtract background as median
		hdu.data *= 2.7048E-06 ### convert from DN to flux in Janskies, from this table: http://wise2.ipac.caltech.edu/docs/release/allsky/expsup/sec2_3f.html

		#### put onto same scale
		data2, footprint = reproject_exact(hdu, hdu_original)

		### plot the main result
		data1_slice = data1[size[2]:size[3],size[0]:size[1]]
		data2_slice = data2[size[2]:size[3],size[0]:size[1]]
		plot_color_contour(ax[5],data1_slice, data2_slice, contours[0],contours[1], maxlim=maxlim, color_limits=color_limits)

		#ax[5].text(xs, ys, 'contours:' +contours[0]+'-'+contours[1],transform=ax[5].transAxes)
		ys -= dely

		### labels and limits
		ax[5].text(0.98,0.93,objname,transform=ax[5].transAxes,ha='right')
		ax[5].text(0.98,0.88,r'f$_{\mathrm{AGN,MIR}}$='+"{:.2f}".format(fagn),transform=ax[5].transAxes, ha='right')
		ax[5].set_title('WISE colors on\nSDSS imaging')

		#### CALIBRATION PLOT
		flux_color = convert_to_color(data1_slice, data2_slice,contours[0],contours[1],minflux=1e-10)

		img = ax[0].imshow(data1_noconv[size[2]:size[3],size[0]:size[1]], origin='lower',extent=extent)
		cbar = fig.colorbar(img, ax=ax[0])
		cbar.formatter.set_powerlimits((0, 0))
		cbar.update_ticks()
		ax[0].set_title(contours[0]+', \n raw')

		img = ax[1].imshow(data1_slice, origin='lower',extent=extent)
		cbar = fig.colorbar(img, ax=ax[1])
		cbar.formatter.set_powerlimits((0, 0))
		cbar.update_ticks()
		ax[1].set_title(contours[0]+', \n convolved to W2 PSF')

		img = ax[2].imshow(data2_slice, origin='lower',extent=extent)
		cbar = fig.colorbar(img, ax=ax[2])
		cbar.formatter.set_powerlimits((0, 0))
		cbar.update_ticks()
		ax[2].set_title(contours[1]+', \n raw')

		img = ax[3].imshow(flux_color, origin='lower',extent=extent,vmin=color_limits[0],vmax=color_limits[1])
		cbar = fig.colorbar(img, ax=ax[3])
		ax[3].set_title(contours[0]+'-'+contours[1]+', \n raw')

		### don't trust anything less than X times the max!
		max1 = np.nanmax(data1_slice)
		max2 = np.nanmax(data2_slice)
		background = (data1_slice < max1*maxlim) | (data2_slice < max2*maxlim)
		flux_color[background] = np.nan


		img = ax[4].imshow(flux_color, origin='lower',extent=extent,vmin=color_limits[0],vmax=color_limits[1])
		cbar = fig.colorbar(img, ax=ax[4])
		cbar.formatter.set_powerlimits((0, 0))
		cbar.update_ticks()
		ax[4].set_title(contours[0]+'-'+contours[1]+', \n background removed')

		ax[4].plot([start,start+wise_psf],[start,start],lw=2,color='k')
		ax[4].text(start+wise_psf/2.,start+1, '6"', ha='center')
		ax[4].set_xlim(-xdist,xdist) # reset plot limits b/c of text stuff
		ax[4].set_ylim(-ydist,ydist)

		#### now plot the SED
		agn_color, noagn_color = '#FF3D0D', '#1C86EE'
		wavlims = (1,30)
		wav_idx = (pdata['observables']['wave'][ii]/1e4 > wavlims[0]) & (pdata['observables']['wave'][ii]/1e4 < wavlims[1])
		sedax.plot(pdata['observables']['wave'][ii][wav_idx]/1e4,pdata['observables']['agn_on_spec'][ii][wav_idx], lw=2.5, alpha=0.5, color=agn_color)
		sedax.plot(pdata['observables']['wave'][ii][wav_idx]/1e4,pdata['observables']['agn_off_spec'][ii][wav_idx], lw=2.5, alpha=0.5, color=noagn_color)
		if type(pdata['observables']['spit_lam'][ii]) is np.ndarray:
			wav_idx = (pdata['observables']['spit_lam'][ii]/1e4 > wavlims[0]) & (pdata['observables']['spit_lam'][ii]/1e4 < wavlims[1])
			sedax.plot(pdata['observables']['spit_lam'][ii][wav_idx]/1e4,pdata['observables']['spit_flux'][ii][wav_idx], lw=2.5, alpha=0.5, color='black')
		if type(pdata['observables']['ak_lam'][ii]) is np.ndarray:
			wav_idx = (pdata['observables']['ak_lam'][ii]/1e4 > wavlims[0]) & (pdata['observables']['ak_lam'][ii]/1e4 < wavlims[1])
			sedax.plot(pdata['observables']['ak_lam'][ii][wav_idx]/1e4,pdata['observables']['ak_flux'][ii][wav_idx], lw=2.5, alpha=0.5, color='black')

		### write down Vega colors
		sedax.text(0.95,0.1,'W1-W2(AGN ON)='+'{:.2f}'.format(pdata['observables']['agn_on_mag'][ii]),transform=sedax.transAxes,color=agn_color,ha='right')
		sedax.text(0.95,0.16,'W1-W2(AGN OFF)='+'{:.2f}'.format(pdata['observables']['agn_off_mag'][ii]),transform=sedax.transAxes,color=noagn_color,ha='right')
		sedax.text(0.95,0.22,'W1-W2(OBS)='+'{:.2f}'.format(pdata['observables']['obs_mag'][ii]),transform=sedax.transAxes,color='black',ha='right')

		lsfr = pdata['lsfr'][idx]
		if lsfr > 0:
			sedax.text(1.15,0.5,r'L$_{\mathrm{X}}$(obs)/L$_{\mathrm{X}}$(SFR)='+'{:.2f}'.format(lsfr),transform=sedax.transAxes,color='black',fontsize=18,weight='bold')
		else: 
			sedax.text(1.15,0.5,r'No X-ray information',transform=sedax.transAxes,color='black',fontsize=18,weight='bold')
		bpt = pdata['bpt'][idx]
		if bpt == 'None':
			sedax.text(1.15,0.42,'No BPT measurement',transform=sedax.transAxes,color='black',fontsize=18,weight='bold')
		else: 
			sedax.text(1.15,0.42,'BPT: '+bpt,transform=sedax.transAxes,color='black',fontsize=18,weight='bold')

		### scaling and labels
		# the y-axis log scale takes the y-specific keywords (nonposy/subsy);
		# passing the x-axis names to set_yscale is rejected by LogScale
		sedax.set_yscale('log',nonposy='clip',subsy=(1,2,4))
		sedax.set_xscale('log',nonposx='clip',subsx=(1,2,4))
		sedax.xaxis.set_minor_formatter(minorFormatter)
		sedax.xaxis.set_major_formatter(majorFormatter)

		sedax.set_xlabel(r'wavelength $\mu$m')
		sedax.set_ylabel(r'f$_{\nu}$')

		sedax.axvline(3.4, linestyle='--', color='0.5',lw=1.5,alpha=0.8,zorder=-5)
		sedax.axvline(4.6, linestyle='--', color='0.5',lw=1.5,alpha=0.8,zorder=-5)

		sedax.set_xlim(wavlims)

		padding = ''
		if ii <= 9:
			padding='0'

		plt.savefig(outfolder+'/'+padding+str(ii)+'_'+objname+'.png',dpi=150)
		plt.close()
Ejemplo n.º 24
0
    def resampleToCommonFrame(self, adinputs=None, **params):
        """
        This primitive applies the transformation encoded in the input images
        WCSs to align them with a reference image, in reference image pixel
        coordinates. The reference image is taken to be the first image in
        the input list.
        
        By default, the transformation into the reference frame is done via
        interpolation. The interpolator parameter specifies the interpolation 
        method. The options are nearest-neighbor, bilinear, or nth-order 
        spline, with n = 2, 3, 4, or 5. If interpolator is None, 
        no interpolation is done: the input image is shifted by an integer
        number of pixels, such that the center of the frame matches up as
        well as possible. The variance plane, if present, is transformed in
        the same way as the science data.
        
        The data quality plane, if present, must be handled a little
        differently. DQ flags are set bit-wise, such that each pixel is the 
        sum of any of the following values: 0=good pixel,
        1=bad pixel (from bad pixel mask), 2=nonlinear, 4=saturated, etc.
        To transform the DQ plane without losing flag information, it is
        unpacked into separate masks, each of which is transformed in the same
        way as the science data. A pixel is flagged if it had greater than
        1% influence from a bad pixel. The transformed masks are then added
        back together to generate the transformed DQ plane.
        
        In order not to lose any data, the output image arrays (including the
        reference image's) are expanded with respect to the input image arrays.
        The science and variance data arrays are padded with zeros; the DQ
        plane is padded with 16s.
        
        The WCS keywords in the headers of the output images are updated
        to reflect the transformation, and each image receives area keywords
        (built by _build_area_keys) describing where its original data lies
        in the output frame.

        Parameters
        ----------
        suffix: str
            suffix to be added to output files
        interpolator: str
            desired interpolation [nearest | linear | spline2 | spline3 |
                                   spline4 | spline5]
        trim_data: bool
            trim image to size of reference image?

        Returns
        -------
        list
            the input AstroData objects, modified in place (returned
            unchanged if fewer than two inputs are given)

        Raises
        ------
        IOError
            if any input has more than one extension
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        interpolator = params["interpolator"]
        trim_data = params["trim_data"]
        sfx = params["suffix"]

        if len(adinputs) < 2:
            log.warning("No alignment will be performed, since at least two "
                        "input AstroData objects are required for "
                        "resampleToCommonFrame")
            return adinputs

        if not all(len(ad)==1 for ad in adinputs):
            raise IOError("All input images must have only one extension.")

        # --------------------  BEGIN establish reference frame  -------------------
        ref_image = adinputs[0]
        ref_wcs = WCS(ref_image[0].hdr)
        ref_shape = ref_image[0].data.shape
        ref_corners = at.get_corners(ref_shape)
        naxis = len(ref_shape)

        # first pass: get output image shape required to fit all
        # data in output by transforming corner coordinates of images
        all_corners = [ref_corners]
        corner_values = _transform_corners(adinputs[1:], all_corners, ref_wcs,
                                           interpolator)
        all_corners, xy_img_corners, shifts = corner_values
        refoff, out_shape = _shifts_and_shapes(all_corners, ref_shape, naxis,
                                               interpolator, trim_data, shifts)
        ref_corners = [(corner[1] - refoff[1] + 1, corner[0] - refoff[0] + 1) # x,y
                       for corner in ref_corners]
        area_keys = _build_area_keys(ref_corners)

        # refoff is in (y, x) order, hence the swapped indices for CRPIX1/2
        ref_image.hdr.set('CRPIX1', ref_wcs.wcs.crpix[0]-refoff[1],
                          self.keyword_comments["CRPIX1"])
        ref_image.hdr.set('CRPIX2', ref_wcs.wcs.crpix[1]-refoff[0],
                          self.keyword_comments["CRPIX2"])
        padding = tuple((int(-cen),out-int(ref-cen)) for cen, out, ref in
                        zip(refoff, out_shape, ref_shape))
        _pad_image(ref_image, padding)

        for key in area_keys:
            ref_image[0].hdr.set(*key)
        out_wcs = WCS(ref_image[0].hdr)
        ref_image.update_filename(suffix=sfx, strip=True)
        # -------------------- END establish reference frame -----------------------

        # --------------------   BEGIN transform data ...  -------------------------
        for ad, corners in zip(adinputs[1:], xy_img_corners):
            if interpolator:
                trans_parameters = _composite_transformation_matrix(ad,
                                        out_wcs, self.keyword_comments)
                matrix, matrix_det, img_wcs, offset = trans_parameters
            else:
                # img_wcs was previously never assigned on this path, so the
                # corner transformation below raised NameError. Capture the
                # image's own WCS before _composite_from_ref_wcs runs, since it
                # presumably rewrites the header WCS (as done for the
                # reference above) — TODO confirm against that helper.
                img_wcs = WCS(ad[0].hdr)
                shift = _composite_from_ref_wcs(ad, out_wcs,
                                                self.keyword_comments)
                matrix_det = 1.0

            # transform corners to find new location of original data
            data_corners = out_wcs.all_world2pix(
                img_wcs.all_pix2world(corners, 0), 1)
            area_keys = _build_area_keys(data_corners)

            if interpolator:
                kwargs = {'matrix': matrix, 'offset': offset,
                          'order': interpolators[interpolator],
                          'output_shape': out_shape}
                new_var = None if ad[0].variance is None else \
                    affine_transform(ad[0].variance, cval=0.0, **kwargs)
                new_mask = _transform_mask(ad[0].mask if ad[0].mask is not None
                        else np.zeros_like(ad[0].data, dtype=DQ.datatype), **kwargs)
                if hasattr(ad[0], 'OBJMASK'):
                    ad[0].OBJMASK = _transform_mask(ad[0].OBJMASK, **kwargs)
                ad[0].reset(affine_transform(ad[0].data, cval=0.0, **kwargs),
                            new_mask, new_var)
            else:
                padding = tuple((int(-s), out-int(img-s)) for s, out, img in
                                zip(shift, out_shape, ad[0].data.shape))
                _pad_image(ad, padding)

            if abs(1.0 - matrix_det) > 1e-6:
                log.fullinfo("Multiplying by {} to conserve flux".format(matrix_det))
                # Allow the arith toolbox to do the multiplication
                # so that variance is handled correctly
                ad.multiply(matrix_det)

            # These area keys describe *this* image's data footprint; they
            # belong in its own header (writing them to ref_image overwrote
            # the reference's keys with those of each later image).
            for key in area_keys:
                ad[0].hdr.set(*key)

            # Timestamp and update filename
            ad.update_filename(suffix=sfx, strip=True)
        return adinputs
Ejemplo n.º 25
0
def GetHiResImage(ID):
    '''
    Queries the Palomar Observatory Sky Survey II catalog to
    obtain a higher resolution optical image of the star with EPIC number
    :py:obj:`ID`.

    The POSS-II cutout (requested with ``h=3&w=3``) is mapped onto the pixel
    grid of the star's first K2 target pixel file and resampled with cubic
    interpolation on a grid 10x finer than the K2 pixels.

    :returns: The resampled image (2D array), or ``None`` if the DSS
        server does not respond with HTTP status 200.
    '''

    # Get the TPF info
    client = kplr.API()
    star = client.k2_star(ID)
    k2ra = star.k2_ra
    k2dec = star.k2_dec
    tpf = star.get_target_pixel_files()[0]
    with tpf.open() as f:
        k2wcs = WCS(f[2].header)
        shape = np.array(f[1].data.field('FLUX'), dtype='float64')[0].shape

    # Build the sexagesimal RA/Dec strings ('+'-separated) for the POSS URL.
    # (Local names avoid shadowing the builtin `min`.)
    ra_hours = k2ra * 24 / 360.
    hou = int(ra_hours)
    ra_min = int(60 * (ra_hours - hou))
    ra_sec = 60 * (60 * (ra_hours - hou) - ra_min)
    ra_str = '%02d+%02d+%.2f' % (hou, ra_min, ra_sec)
    sgn = '' if np.sign(k2dec) >= 0 else '-'
    abs_dec = np.abs(k2dec)
    deg = int(abs_dec)
    dec_min = int(60 * (abs_dec - deg))
    # float divisor so the arcminutes are not truncated by integer division
    dec_sec = 3600 * (abs_dec - deg - dec_min / 60.)
    dec_str = '%s%02d+%02d+%.1f' % (sgn, deg, dec_min, dec_sec)
    url = 'https://archive.stsci.edu/cgi-bin/dss_search?v=poss2ukstu_red&' + \
          'r=%s&d=%s&e=J2000&h=3&w=3&f=fits&c=none&fov=NONE&v3=' % (ra_str, dec_str)

    # Query the server; the response is closed even on the early return
    request = urllib.request.Request(url)
    with urllib.request.urlopen(request) as handler:
        if int(handler.getcode()) != 200:
            # Unavailable
            return None
        data = handler.read()

    # Write the FITS payload to a temporary file so it can be read back by
    # pyfits and WCS; the file is removed once both have been read.
    tmp = NamedTemporaryFile("wb", delete=False)
    try:
        tmp.write(data)
        tmp.flush()
        os.fsync(tmp.fileno())
        tmp.close()

        # Now open the POSS fits file (copy the data before the file goes away)
        with pyfits.open(tmp.name) as ff:
            img = np.array(ff[0].data, dtype='float64')
        pwcs = WCS(tmp.name)
    finally:
        tmp.close()
        os.remove(tmp.name)

    # Map every POSS pixel (0-based, row-major order) onto K2 pixel
    # coordinates with one vectorised WCS call instead of one per pixel.
    rows, cols = np.indices(img.shape)
    pra, pdec = pwcs.all_pix2world(cols.ravel().astype('float64'),
                                   rows.ravel().astype('float64'), 0)
    kx, ky = k2wcs.all_world2pix(pra, pdec, 0)
    xy = np.column_stack((kx, ky))
    z = img.ravel()

    # Resample
    grid_x, grid_y = np.mgrid[-0.5:shape[1] - 0.5:0.1, -0.5:shape[0] - 0.5:0.1]
    resampled = griddata(xy, z, (grid_x, grid_y), method='cubic')

    # Rotate to align with K2 image. Not sure why, but it is necessary
    resampled = np.rot90(resampled)

    return resampled
Ejemplo n.º 26
0
def pick_positions(catalog, filename, separation, refimage=None, wcs_origin=1):
    """
    Assigns positions to fake star list generated by pick_models

    INPUTS:
    -------

    catalog: object with a ``data`` attribute
        Photometry catalog. ``catalog.data.columns`` must contain either
        X/Y (or x/y) pixel columns, or RA/DEC (or ra/dec) columns, in which
        case ``refimage`` is required to convert them to pixels.

    filename: string
        Name of AST list generated by pick_models

    separation: float
        Minimum pixel separation between AST and star in photometry
        catalog provided in the datamodel.

    refimage: string
        Name of the reference image.  If supplied, the method will use the
        reference image header to convert from RA and DEC to X and Y.

    wcs_origin : 0 or 1 (default=1)
        As described in the WCS documentation: "the coordinate in the upper
        left corner of the image. In FITS and Fortran standards, this is 1.
        In Numpy and C standards this is 0."

    OUTPUTS:
    --------

    Ascii table that replaces [filename] with a new version of
    [filename] that contains the necessary position columns for running
    the ASTs though DOLPHOT

    RAISES:
    -------

    RuntimeError
        if the catalog has no X/Y columns and either no ``refimage`` is
        given or the catalog has no RA/DEC columns either.
    """

    noise = 3.0  # Spreads the ASTs in a circular annulus of 3 pixel width instead of all being
    # precisely [separation] from an observed star.

    colnames = catalog.data.columns

    # The old test `"X" or "x" in colnames` was always true, which made the
    # RA/DEC branch below unreachable.
    if ("X" in colnames) or ("x" in colnames):
        # lowercase columns take precedence when both spellings exist
        if "X" in colnames:
            x_positions = catalog.data["X"][:]
            y_positions = catalog.data["Y"][:]
        if "x" in colnames:
            x_positions = catalog.data["x"][:]
            y_positions = catalog.data["y"][:]
    else:
        if not refimage:
            raise RuntimeError(
                "You must supply a Reference Image to determine spatial AST distribution."
            )
        if ("RA" not in colnames) and ("ra" not in colnames):
            raise RuntimeError(
                "Your catalog does not supply X, Y or RA, DEC information for spatial AST distribution"
            )
        if "RA" in colnames:
            ra_positions = catalog.data["RA"][:]
            dec_positions = catalog.data["DEC"][:]
        if "ra" in colnames:
            ra_positions = catalog.data["ra"][:]
            dec_positions = catalog.data["dec"][:]

        wcs = WCS(refimage)
        x_positions, y_positions = wcs.all_world2pix(
            ra_positions, dec_positions, wcs_origin
        )

    astmags = ascii.read(filename)

    n_asts = len(astmags)

    # keep is defined to ensure that no fake stars are put outside of the
    # extent spanned by the catalog positions
    keep = (
        (x_positions > np.min(x_positions) + separation + noise)
        & (x_positions < np.max(x_positions) - separation - noise)
        & (y_positions > np.min(y_positions) + separation + noise)
        & (y_positions < np.max(y_positions) - separation - noise)
    )

    x_positions = x_positions[keep]
    y_positions = y_positions[keep]

    # pick a random (kept) observed star for every AST
    ncat = len(x_positions)
    ind = np.random.random(n_asts) * ncat
    ind = ind.astype("int")

    # Here we generate the circular distribution of ASTs surrounding random observed stars

    separation = np.random.random(n_asts) * noise + separation
    theta = np.random.random(n_asts) * 2.0 * np.pi
    xvar = separation * np.cos(theta)
    yvar = separation * np.sin(theta)

    new_x = x_positions[ind] + xvar
    new_y = y_positions[ind] + yvar
    column1 = 0 * new_x
    column2 = column1 + 1
    column1 = Column(name="zeros", data=column1.astype("int"))
    column2 = Column(name="ones", data=column2.astype("int"))
    column3 = Column(name="X", data=new_x, format="%.2f")
    column4 = Column(name="Y", data=new_y, format="%.2f")
    astmags.add_column(column1, 0)
    astmags.add_column(column2, 1)
    astmags.add_column(column3, 2)
    astmags.add_column(column4, 3)

    ascii.write(astmags, filename, overwrite=True)
Ejemplo n.º 27
0
    print ("Length of the dr2 table: %d" % len(dr2nearby))

    # Move table into a dictionary object
    dr2Objects = []
    for index, row in enumerate(dr2nearby):
        IPHAS2name = row["IPHAS2"]
        dr2Object = {}
        dr2Object["name"] = IPHAS2name
        dr2Object["ra"] = row["RAJ2000"]
        dr2Object["dec"] = row["DEJ2000"]
        dr2Object["class"] = row["mergedClass"]
        dr2Object["pStar"] = row["pStar"]
        dr2Object["iClass"] = row["iClass"]
        dr2Object["haClass"] = row["haClass"]
        dr2Object["pixelFWHM"] = row["haSeeing"] / pixelScale
        x, y = wcsSolution.all_world2pix([dr2Object["ra"]], [dr2Object["dec"]], 1)
        dr2Object["x"] = x[0]
        dr2Object["y"] = y[0]
        dr2Objects.append(dr2Object)
        if (index % 100) == 0:
            sys.stdout.write("\rCopying: %d of %d." % (index, len(dr2nearby)))
            sys.stdout.flush()
    sys.stdout.write("\n")
    sys.stdout.flush()

    # Run through all objects and find the ones that are haClass = "+1" but iClass != "+1" and overall class = "+1"
    extendedHaSources = []
    for index, d in enumerate(dr2Objects):
        if (d["class"] == 1) and (d["haClass"] == 1) and (d["iClass"] != 1):
            extendedHaSources.append(index)
Ejemplo n.º 28
0
def plot_composites(pdata,outfolder,contours,contour_colors=True,
                    calibration_plot=True,brown_data=False,paperplot=False):

    ### image qualities
    fs = 10 # fontsize
    maxlim = 0.01 # limit of maximum

    ### contour color limits (customized for W1-W2)
    color_limits = [-1.0,2.6]
    kernel = None

    ### output blobs
    gradient, gradient_error, arcsec,objname_out, obj_size_kpc, background_out = [], [], [], [], [], []

    ### begin loop
    fig = None
    for idx in xrange(len(pdata['objname'])):

        if paperplot:
            if ('NGC 4168' not in pdata['objname'][idx]) & ('NGC 1275' not in pdata['objname'][idx]):
                continue

        ### load object information
        objname = pdata['objname'][idx]
        fagn = pdata['pars']['fagn']['q50'][idx]
        ra, dec = load_coordinates(objname)
        phot_size = load_structure(objname,long_axis=True) # in arcseconds

        ### load image and WCS
        try:
            if brown_data:
                ### convert from DN to flux in Janskies, from this table: 
                # http://wise2.ipac.caltech.edu/docs/release/allsky/expsup/sec2_3f.html
                img1, noise1 = load_image(objname,contours[0]), None
                img1, noise1 = img1*1.9350E-06, noise1*(1.9350e-06)**-2
                img2, noise2 = load_image(objname,contours[1]), None
                img2, noise2 = img2*2.7048E-06, noise2*(2.7048E-06)**-2

                ### translate object location into pixels using WCS coordinates
                wcs = WCS(img1.header)
                pix_center = wcs.all_world2pix([[ra[0],dec[0]]],1)
            else:
                img1, noise1 = load_wise_data(objname,contours[0].split(' ')[1])
                img2, noise2 = load_wise_data(objname,contours[1].split(' ')[1])

                ### translate object location into pixels using WCS coordinates
                wcs = WCS(img1.header)
                pix_center = wcs.all_world2pix([[ra[0],dec[0]]],1)

                if (pix_center.squeeze()[0]-4 > img1.shape[1]) or \
                    (pix_center.squeeze()[1]-4 > img1.shape[0]) or \
                    (np.any(pix_center < 4)):
                    print 'object not in image, checking for additional image'
                    print pix_center, img1.shape
                    img1, noise1 = load_wise_data(objname,contours[0].split(' ')[1],load_other = True)
                    img2, noise2 = load_wise_data(objname,contours[1].split(' ')[1],load_other = True)

                    wcs = WCS(img1.header)
                    pix_center = wcs.all_world2pix([[ra[0],dec[0]]],1)
                    print pix_center, img1.shape
        except:
            gradient.append(None)
            gradient_error.append(None)
            arcsec.append(None)
            kpc.append(None)
            objname_out.append(None)
            continue

        size = calc_dist(wcs, pix_center, phot_size, img1.data.shape)

        ### convert inverse variance to noise
        noise1.data = (1./noise1.data)**0.5
        noise2.data = (1./noise2.data)**0.5

        ### build image extents
        extent = image_extent(size,pix_center,wcs)

        ### convolve W1 to W2 resolution
        w1_convolved, kernel = match_resolution(img1.data,contours[0],contours[1],
                                                    kernel=kernel,data1_res=px_scale)
        w1_convolved_noise, kernel = match_resolution(noise1.data,contours[0],contours[1],
                                                      kernel=kernel,data1_res=px_scale)

        #### put onto same scale, and grab slices
        data2, footprint = reproject_exact(img2, img1.header)
        noise2, footprint = reproject_exact(noise2, img1.header)

        img1_slice = w1_convolved[size[2]:size[3],size[0]:size[1]]
        img2_slice = data2[size[2]:size[3],size[0]:size[1]]
        noise1_slice = w1_convolved_noise[size[2]:size[3],size[0]:size[1]]
        noise2_slice = noise2[size[2]:size[3],size[0]:size[1]]

        ### subtract background from both images
        # identify background pixels.
        # background is any pixel consistent within X sigma of background!
        sigma = 3.0
        if paperplot:
            sigma = 5.0
        mean1, median1, std1 = sigma_clipped_stats(w1_convolved, sigma=sigma,iters=10)
        background1 = img1_slice < (median1+std1) 
        img1_slice -= median1
        mean2, median2, std2 = sigma_clipped_stats(data2, sigma=sigma, iters=10)
        background2 = img2_slice < (median2+std2)
        img2_slice -= median2

        #### calculate the color
        flux_color  = convert_to_color(img1_slice, img2_slice,None,None,contours[0],contours[1],
                                       minflux=-np.inf, vega_conversions=brown_data)

        ### don't show any "background" pixels!
        background = background1 | background2
        flux_color[background] = np.nan

        ### plot colormap
        count = 0
        if paperplot:
            if fig is None:
                fig, axall = plt.subplots(1,2, figsize=(12,6))
                fig.subplots_adjust(right=0.8,wspace=0.4,hspace=0.3,left=0.12)
                cb_ax = fig.add_axes([0.83, 0.15, 0.05, 0.7])
                ax = np.ravel(axall[0])
            else:
                ax = np.ravel(axall[1])
                count = 1
            vmin, vmax = -0.4,1.05
        else:
            fig, ax = plt.subplots(1,2, figsize=(12,6))
            vmin, vmax = color_limits[0], color_limits[1]
        ax = np.ravel(ax)
        img = ax[0].imshow(flux_color, origin='lower',extent=extent,vmin=vmin,vmax=vmax,cmap='plasma')

        if not paperplot:
            cbar = fig.colorbar(img, ax=ax[0])
        elif count == 0:
            cbar = fig.colorbar(img, cax=cb_ax)
            cbar.set_label('(W1-W2) [Vega]', fontdict={'fontsize':18})

        ax[0].set_xlabel(r'$\Delta$(arcsec)')
        ax[0].set_ylabel(r'$\Delta$(arcsec)')

        ### plot W1 contours
        if not paperplot:
            plot_contour(ax[0],np.log10(img2_slice),ncontours=20)

        ### find image center in W2 image, and mark it
        # do this by finding the source closest to center
        tbl = []
        nthresh, box_size = 20, 4
        fake_noise2_error = copy.copy(noise2_slice)
        bad = np.logical_or(np.isinf(noise2_slice),np.isnan(noise2_slice))
        fake_noise2_error[bad] = fake_noise2_error[~bad].max()
        while len(tbl) < 1:
            threshold = nthresh * std1 # peak threshold, @ 20 sigma
            tbl = find_peaks(img2_slice, threshold, box_size=box_size, subpixel=True, border_width=3, error = fake_noise2_error)
            nthresh -=2
        
            if nthresh < 2:
                nthresh = 20
                box_size += 1

        '''
        center = np.array(img2_slice.shape)/2.
        idxmax = ((center[0]-tbl['x_peak'])**2 + (center[1]-tbl['y_peak'])**2).argmin()
        fig, ax = plt.subplots(1,1, figsize=(6,6))
        ax.plot(tbl['x_peak'][idxmax],tbl['y_peak'][idxmax],'x',color='red',ms=10)
        ax.imshow(img2_slice,origin='lower')
        plot_contour(ax, np.log10(img2_slice),ncontours=20)
        plt.show()
        '''


        ### find size of biggest one
        imgcenter = np.array(img2_slice.shape)/2.
        idxmax = ((imgcenter[0]-tbl['x_centroid'])**2 + (imgcenter[1]-tbl['y_centroid'])**2).argmin()
        center = [tbl['x_centroid'][idxmax], tbl['y_centroid'][idxmax]]

        ### find center in arcseconds (NEW)
        center_coordinates = SkyCoord.from_pixel(imgcenter[0],imgcenter[1],wcs)
        x_pos_obj = SkyCoord.from_pixel(center[0],imgcenter[1],wcs)
        y_pos_obj = SkyCoord.from_pixel(imgcenter[0],center[1],wcs)
        xarcsec = x_pos_obj.separation(center_coordinates).arcsec
        if center[0] < imgcenter[0]:
            xarcsec = -xarcsec
        yarcsec = y_pos_obj.separation(center_coordinates).arcsec
        if center[1] < imgcenter[1]:
            yarcsec = -yarcsec

        #xarcsec = (extent[1]-extent[0])*center[0]/float(img2_slice.shape[0]) + extent[0]
        #yarcsec = (extent[3]-extent[2])*center[1]/float(img2_slice.shape[1]) + extent[2]
        ax[0].scatter(xarcsec,yarcsec,color='black',marker='x',s=50,linewidth=2)

        ### add in WISE PSF
        wise_psf = 6 # in arcseconds
        start = 0.85*extent[0]

        if not paperplot:
            ax[0].plot([start,start+wise_psf],[start,start],lw=2,color='k')
            ax[0].text(start+wise_psf/2.,start+1, '6"', ha='center')
            ax[0].set_xlim(extent[0],extent[1]) # reset plot limits b/c of text stuff
            ax[0].set_ylim(extent[2],extent[3])
        else:
            ax[0].set_xlim(-65,65)
            ax[0].set_ylim(-65,65)

        ### gradient
        phys_scale = float(1./WMAP9.arcsec_per_kpc_proper(pdata['z'][idx]).value)
        if objname == 'CGCG 436-030':
            center[1] = center[1]+1.5
            yarcsec += px_scale*1.5

        grad, graderr, x_arcsec, back = measure_gradient(img1_slice,img2_slice, 
                                  noise1_slice, noise2_slice, background,
                                  ax, center,
                                  tbl['peak_value'][idxmax], (xarcsec,yarcsec),
                                  phys_scale,paperplot=paperplot)

        obj_size_phys = phot_size*phys_scale
        if not paperplot:
            ax[1].text(0.05,0.06,r'f$_{\mathrm{AGN,MIR}}$='+"{:.2f}".format(pdata['pars']['fagn']['q50'][idx])+\
                                    ' ('+"{:.2f}".format(pdata['pars']['fagn']['q84'][idx]) +
                                    ') ('+"{:.2f}".format(pdata['pars']['fagn']['q16'][idx])+')',
                                    transform=ax[1].transAxes,color='black',fontsize=9)

            ax[1].axvline(phot_size, linestyle='--', color='0.2',lw=2,zorder=-1)

        else:
            ax[0].text(0.98,0.94,objname,transform=ax[0].transAxes,fontsize=14,weight='bold',ha='right')
            ax[0].text(0.98,0.88,r'$\nabla$(2 kpc)='+"{:.2f}".format(grad[1]),fontsize=14,transform=ax[0].transAxes,ha='right')

        gradient.append(grad)
        gradient_error.append(graderr)
        arcsec.append(x_arcsec)
        obj_size_kpc.append(obj_size_phys)
        objname_out.append(objname)
        background_out.append(back)
        print objname, back

        # I/O
        outname = outfolder+'/'+objname+'.png'
        if paperplot:
            outname = outfolder+'/sample_wise_gradient.png'

        if (not paperplot) | (count == 1):
            if not paperplot:
                plt.tight_layout()
            plt.savefig(outname,dpi=150)
            plt.close()

    out = {
            'gradient': np.array(gradient),
            'gradient_error': np.array(gradient_error),
            'arcsec': np.array(arcsec),
            'obj_size_brown_kpc': np.array(obj_size_kpc),
            'objname': objname_out,
            'background_fraction': np.array(background_out)
          }
    if not paperplot:
        pickle.dump(out,open(outfile, "wb"))
Ejemplo n.º 29
0
def pick_positions_from_map(
    catalog,
    chosen_seds,
    input_map,
    N_bins,
    Npermodel,
    outfile=None,
    refimage=None,
    refimage_hdu=1,
    wcs_origin=1,
    Nrealize=1,
    set_coord_boundary=None,
    region_from_filters=None,
):
    """
    Spreads a set of fake stars across regions of similar values,
    given a map file generated by 'create background density map' or
    'create stellar density map' in the tools directory.

    The tiles of the given map are divided across a given
    number of bins. Each bin will then have its own set of tiles,
    which constitute a region on the image.

    Then, for each bin, the given set of fake stars is duplicated,
    and the stars are assigned random positions within this region.

    This way, it can be ensured that enough ASTs are performed for each
    regime of the map, making it possible to have a separate noise model
    for each of these regions.

    Parameters
    ----------

    catalog: Observations object
        Provides the observations

    chosen_seds: astropy Table
        Table containing fake stars to be duplicated and assigned positions

    input_map: str
        Path to a hd5 file containing the file written by a DensityMap

    N_bins: int
        The number of bins for the range of background density values.
        The bins will be picked on a linear grid, ranging from the
        minimum to the maximum value of the map. Then, each tile will be
        put in a bin, so that a set of tiles of the map is obtained for
        each range of source density/background values.

    Npermodel: int
        Number of repeats of each model in chosen_seds. Only used for the
        progress messages.

    outfile: str (default=None)
        If given, the output table is also written to this ascii file.

    refimage: str
        Path to fits image that is used for the positions. If none is
        given, the ra and dec will be put in the x and y output columns
        instead.

    refimage_hdu: int (default=1)
        index of the HDU from which to get the header, which will be used
        to extract WCS information

    wcs_origin : 0 or 1 (default=1)
        As described in the WCS documentation: "the coordinate in the upper
        left corner of the image. In FITS and Fortran standards, this is 1.
        In Numpy and C standards this is 0."

    Nrealize: integer
        The number of times each model should be repeated for each
        background regime. This is to sample the variance due to
        variations within each region, for each individual model.

    set_coord_boundary : None, or list of 2 numpy arrays
        If provided, these RA/Dec coordinates will be used to limit the
        region over which ASTs are generated.  Input should be list of two
        arrays, the first RA and the second Dec, ordered sequentially
        around the region (either CW or CCW).

    region_from_filters : None, list of filter name(s), or 'all'
        If provided, ASTs will only be placed in regions with this particular
        combination of filter(s).  Or, if 'all' is chosen, ASTs will only be
        placed where there is overlap with all filters.  In practice, this
        means creating a convex hull around the catalog RA/DEC of sources with
        valid values in these filters.  Note that if the region in question is
        a donut, this will put ASTs in the hole.  This will also only work
        properly if the region is a convex polygon.  A solution to these needs
        to be figured out at some point.

    Returns
    -------
    astropy Table: List of fake stars, with magnitudes and positions.
    Every non-empty map bin receives one full copy of chosen_seds repeated
    Nrealize times. The position columns are X/Y when a refimage is given,
    and RA/DEC otherwise.
    - optionally -
    ascii file of this table, written to outfile

    Raises
    ------
    RuntimeError
        If the catalog has neither X/Y nor RA/DEC columns, or if
        region_from_filters is not a list or 'all'.

    """

    # if refimage exists, extract WCS info
    if refimage is None:
        wcs = None
    else:
        with fits.open(refimage) as hdu:
            imagehdu = hdu[refimage_hdu]
            wcs = WCS(imagehdu.header)

    # if appropriate information is given, extract the x/y positions so that
    # there are no ASTs generated outside of the catalog footprint
    colnames = catalog.data.columns
    xy_pos = False
    radec_pos = False

    # if x/y in catalog, save them (lower-case names win if both exist)
    if ("X" in colnames) or ("x" in colnames):
        xy_pos = True
        if "X" in colnames:
            x_positions = catalog.data["X"][:]
            y_positions = catalog.data["Y"][:]
        if "x" in colnames:
            x_positions = catalog.data["x"][:]
            y_positions = catalog.data["y"][:]

    # if RA/Dec in catalog, save them (lower-case names win if both exist)
    if ("RA" in colnames) or ("ra" in colnames):
        radec_pos = True
        if "RA" in colnames:
            ra_positions = catalog.data["RA"][:]
            dec_positions = catalog.data["DEC"][:]
        if "ra" in colnames:
            ra_positions = catalog.data["ra"][:]
            dec_positions = catalog.data["dec"][:]

    # if only one of those exists and there's a refimage, derive the other
    # one from the coordinates that ARE available
    if wcs is not None:
        if xy_pos and not radec_pos:
            radec_pos = True
            ra_positions, dec_positions = wcs.all_pix2world(
                x_positions, y_positions, wcs_origin
            )
        elif radec_pos and not xy_pos:
            xy_pos = True
            x_positions, y_positions = wcs.all_world2pix(
                ra_positions, dec_positions, wcs_origin
            )

    # if no x/y or ra/dec in the catalog, raise error
    if not xy_pos and not radec_pos:
        raise RuntimeError(
            "Your catalog does not supply X/Y or RA/DEC information to ensure ASTs are within catalog boundary"
        )

    # create path containing the positions
    catalog_boundary_xy = None
    catalog_boundary_radec = None
    if xy_pos:
        catalog_boundary_xy = cut_catalogs.convexhull_path(x_positions, y_positions)
    if radec_pos:
        catalog_boundary_radec = cut_catalogs.convexhull_path(
            ra_positions, dec_positions
        )

    # if coord_boundary set, define an additional boundary for ASTs
    coord_boundary_xy = None
    coord_boundary_radec = None
    if set_coord_boundary is not None:
        # the boundary is given in RA/Dec, so its RA/Dec path never needs
        # catalog coordinates; the x/y version needs a WCS
        coord_boundary_radec = Path(
            np.array([set_coord_boundary[0], set_coord_boundary[1]]).T
        )
        if wcs is not None:
            bounds_x, bounds_y = wcs.all_world2pix(
                set_coord_boundary[0], set_coord_boundary[1], wcs_origin
            )
            coord_boundary_xy = Path(np.array([bounds_x, bounds_y]).T)

    # if region_from_filters is set, define an additional boundary for ASTs
    filt_reg_boundary_xy = None
    filt_reg_boundary_radec = None
    if region_from_filters is not None:
        # need catalog file from datamodel
        importlib.reload(datamodel)

        # 1. find the sub-list of sources
        if isinstance(region_from_filters, list):
            # good stars with user-defined partial overlap
            _, good_stars = cut_catalogs.cut_catalogs(
                datamodel.obsfile,
                "N/A",
                flagged=True,
                flag_filter=region_from_filters,
                no_write=True,
            )
        elif region_from_filters == "all":
            # good stars only with fully overlapping region
            _, good_stars = cut_catalogs.cut_catalogs(
                datamodel.obsfile, "N/A", partial_overlap=True, no_write=True
            )
        else:
            raise RuntimeError("Invalid argument for region_from_filters")

        # 2. define the Path object for the convex hull (one or both systems)
        if xy_pos:
            filt_reg_boundary_xy = cut_catalogs.convexhull_path(
                x_positions[good_stars == 1], y_positions[good_stars == 1]
            )
        if radec_pos:
            filt_reg_boundary_radec = cut_catalogs.convexhull_path(
                ra_positions[good_stars == 1], dec_positions[good_stars == 1]
            )

    # Load the background map
    print(Npermodel, " repeats of each model in each map bin")

    bdm = density_map.BinnedDensityMap.create(input_map, N_bins)
    tile_vals = bdm.tile_vals()
    max_val = np.amax(tile_vals)
    min_val = np.amin(tile_vals)
    tiles_foreach_bin = bdm.tiles_foreach_bin()

    tile_ra_min, tile_dec_min = bdm.min_ras_decs()
    tile_ra_delta, tile_dec_delta = bdm.delta_ras_decs()

    # Remove any of the tiles that aren't contained within user-imposed
    # constraints (if any)
    if (set_coord_boundary is not None) or (region_from_filters is not None):

        for i, tile_set in enumerate(tiles_foreach_bin):

            # keep track of which indices to discard
            keep_tile = np.ones(len(tile_set), dtype=bool)

            for j, tile in enumerate(tile_set):

                # corners of the tile
                ra_min = tile_ra_min[tile]
                ra_max = tile_ra_min[tile] + tile_ra_delta[tile]
                dec_min = tile_dec_min[tile]
                dec_max = tile_dec_min[tile] + tile_dec_delta[tile]

                # make a box object for the tile
                tile_box_radec = box(ra_min, dec_min, ra_max, dec_max)
                tile_box_xy = None
                if wcs is not None:
                    bounds_x, bounds_y = wcs.all_world2pix(
                        np.array([ra_min, ra_max]),
                        np.array([dec_min, dec_max]),
                        wcs_origin,
                    )

                    tile_box_xy = box(
                        np.min(bounds_x),
                        np.min(bounds_y),
                        np.max(bounds_x),
                        np.max(bounds_y),
                    )

                # discard tile if there's no overlap with user-imposed regions

                # - set_coord_boundary
                if set_coord_boundary is not None:
                    # coord boundary is input in RA/Dec, and tiles are RA/Dec,
                    # so there's no need to check the x/y version of either
                    if (
                        Polygon(coord_boundary_radec.vertices)
                        .intersection(tile_box_radec)
                        .area
                        == 0
                    ):
                        keep_tile[j] = False

                # - region_from_filters (prefer x/y when available)
                if region_from_filters is not None:
                    if filt_reg_boundary_xy and tile_box_xy:
                        if (
                            Polygon(filt_reg_boundary_xy.vertices)
                            .intersection(tile_box_xy)
                            .area
                            == 0
                        ):
                            keep_tile[j] = False
                    elif filt_reg_boundary_radec and tile_box_radec:
                        if (
                            Polygon(filt_reg_boundary_radec.vertices)
                            .intersection(tile_box_radec)
                            .area
                            == 0
                        ):
                            keep_tile[j] = False
                    else:
                        warnings.warn(
                            "Unable to use regions_from_filters to remove SD/bg tiles"
                        )

            # remove anything that needs to be discarded
            tiles_foreach_bin[i] = tile_set[keep_tile]

    # Remove empty bins
    tile_sets = [tile_set for tile_set in tiles_foreach_bin if len(tile_set)]
    print(
        "{0} non-empty map bins (out of {1}) found between {2} and {3}".format(
            len(tile_sets), N_bins, min_val, max_val
        )
    )

    # Repeat the seds Nrealize times (sample each on at Nrealize
    # different positions, in each region)
    repeated_seds = np.repeat(chosen_seds, Nrealize)
    Nseds_per_region = len(repeated_seds)
    # One full copy of the repeated seds per bin: np.tile keeps each
    # contiguous block of Nseds_per_region rows a complete set of models
    # (np.repeat would hand each bin a different subset of the seds)
    repeated_seds = np.tile(repeated_seds, len(tile_sets))

    out_table = Table(repeated_seds, names=chosen_seds.colnames)
    ast_x_list = np.zeros(len(out_table))
    ast_y_list = np.zeros(len(out_table))

    # Boundaries every sampled position must fall inside, expressed in the
    # coordinate system used for sampling: RA/Dec without a WCS, x/y with one.
    if wcs is None:
        candidate_boundaries = [catalog_boundary_radec]
        if set_coord_boundary is not None:
            candidate_boundaries.append(coord_boundary_radec)
        if region_from_filters is not None:
            candidate_boundaries.append(filt_reg_boundary_radec)
    else:
        candidate_boundaries = [catalog_boundary_xy]
        if set_coord_boundary is not None:
            candidate_boundaries.append(coord_boundary_xy)
        if region_from_filters is not None:
            candidate_boundaries.append(filt_reg_boundary_xy)
    sample_boundaries = [b for b in candidate_boundaries if b]

    for bin_index, tile_set in enumerate(
        tqdm(
            tile_sets,
            desc="{:.2f} models per map bin".format(Nseds_per_region / Npermodel),
        )
    ):
        start = bin_index * Nseds_per_region
        for i in range(Nseds_per_region):

            # Rejection sampling: draw positions until one falls inside all
            # boundaries. NOTE(review): loops forever if a tile set has no
            # overlap with the boundaries.
            while True:
                # Pick a random tile in this tile set
                tile = np.random.choice(tile_set)
                # Within this tile, pick a random ra and dec
                ra = tile_ra_min[tile] + np.random.random_sample() * tile_ra_delta[tile]
                dec = (
                    tile_dec_min[tile]
                    + np.random.random_sample() * tile_dec_delta[tile]
                )

                if wcs is None:
                    # can't convert to x/y: do everything in RA/Dec
                    x, y = ra, dec
                else:
                    [x], [y] = wcs.all_world2pix(
                        np.array([ra]), np.array([dec]), wcs_origin
                    )

                if all(b.contains_points([[x, y]])[0] for b in sample_boundaries):
                    break

            ast_x_list[start + i] = x
            ast_y_list[start + i] = y

    # I'm just mimicking the format that is produced by the examples
    cs = []
    cs.append(Column(np.zeros(len(out_table), dtype=int), name="zeros"))
    cs.append(Column(np.ones(len(out_table), dtype=int), name="ones"))

    if wcs is None:
        # positions were found using RA/Dec
        cs.append(Column(ast_x_list, name="RA"))
        cs.append(Column(ast_y_list, name="DEC"))
    else:
        # positions were found using x/y
        cs.append(Column(ast_x_list, name="X"))
        cs.append(Column(ast_y_list, name="Y"))

    for i, c in enumerate(cs):
        out_table.add_column(c, index=i)  # insert these columns from the left

    # Write out the table in ascii
    if outfile:
        formats = {k: "%.5f" for k in out_table.colnames[2:]}
        ascii.write(out_table, outfile, overwrite=True, formats=formats)

    return out_table
Ejemplo n.º 30
0
def get_img_offsets(imgList, subPixel=False, mode='wcs'):
    """Compute the relative pixel offsets between the images in a list.

    The per-image positions are derived either from the WCS astrometry in
    each image header or from pairwise cross-correlation (delegated to each
    image's own ``get_img_offsets`` method). The returned offsets are
    measured relative to the image whose position lies closest to the
    median position of the whole list.

    parameters:
    imgList  -- the list of images to be aligned. Each element is expected
                to expose ``arr`` (image array; its shape locates the image
                center), ``header`` (used in 'wcs' mode), ``get_psf()`` (used
                in 'cross_correlate' mode), and a pairwise
                ``get_img_offsets(other, subPixel=..., mode=...)`` method.
    subPixel -- if False (default) the returned offsets are rounded to the
                nearest integer (NumPy round-half-to-even); if True they are
                left at sub-pixel precision.
    mode     -- ['wcs' | 'cross_correlate'] (case-insensitive) the method
                used to align the images. 'wcs' uses the astrometry in the
                headers, while 'cross_correlate' selects the image with the
                broadest PSF as reference and cross-correlates every other
                image against it.

    returns:
    (dx, dy) -- for three or more images, two numpy arrays with one entry
                per image, in ``imgList`` order. For exactly two images the
                result of ``imgList[0].get_img_offsets(imgList[1], ...)`` is
                returned unchanged. For fewer than two images a message is
                printed and ``(0, 0)`` is returned.

    raises:
    ValueError -- if ``imgList`` is not a list, or (for three or more
                  images) if ``mode`` is not recognized.
    """
    # Catch the case where a list of images was not passed
    if not isinstance(imgList, list):
        raise ValueError('imgList variable must be a list of images')

    # There is nothing to align with zero or one image
    if len(imgList) <= 1:
        print('Must have more than one image in the list to be aligned')
        return (0, 0)

    # A single pair is handled directly by the pairwise method
    if len(imgList) == 2:
        return imgList[0].get_img_offsets(imgList[1],
            subPixel=subPixel, mode=mode)

    #**********************************************************************
    # Collect one (x, y) position per image, in imgList order
    #**********************************************************************
    if mode.lower() == 'wcs':
        # Sky position of the central pixel of the first image
        wcs1   = WCS(imgList[0].header)
        x1, y1 = imgList[0].arr.shape[1]//2, imgList[0].arr.shape[0]//2
        skyCoord1 = pixel_to_skycoord(x1, y1, wcs1,
            origin=0, mode='wcs', cls=None)

        imgXpos = [float(x1)]
        imgYpos = [float(y1)]

        # Locate that same sky position in every remaining image. Plain
        # degrees are passed so the WCS call receives ordinary floats.
        for img in imgList[1:]:
            wcs2   = WCS(img.header)
            x2, y2 = wcs2.all_world2pix(skyCoord1.ra.deg, skyCoord1.dec.deg, 0)
            imgXpos.append(float(x2))
            imgYpos.append(float(y2))

    elif mode.lower() == 'cross_correlate':
        # The reference image is the one with the BROADEST PSF, measured as
        # the geometric mean of the best-fit Gaussian widths sx and sy.
        PSFsize = []
        for img in imgList:
            PSFparams, _ = img.get_psf()
            PSFsize.append(np.sqrt(PSFparams['sx']*PSFparams['sy']))

        # np.argmax returns the first index among tied maxima, so ties in
        # PSF size resolve to the earliest such image.
        refImg = imgList[int(np.argmax(PSFsize))]

        imgXpos = []
        imgYpos = []
        for img in imgList:
            if img is refImg:
                # The reference image sits at zero offset by definition
                imgXpos.append(0.0)
                imgYpos.append(0.0)
            else:
                # Offset of this image relative to the reference image
                dx, dy = refImg.get_img_offsets(img,
                    mode='cross_correlate',
                    subPixel=subPixel)
                imgXpos.append(dx)
                imgYpos.append(dy)
    else:
        raise ValueError('Mode not recognized')

    imgXpos = np.array(imgXpos)
    imgYpos = np.array(imgYpos)

    # Pick as final reference the image whose position is closest to the
    # median position of the stack (first one wins on ties).
    medX = np.median(imgXpos)
    medY = np.median(imgYpos)
    imgDist   = np.sqrt((medX - imgXpos)**2.0 + (medY - imgYpos)**2.0)
    centerImg = int(np.argmin(imgDist))

    # Offsets of every image relative to that reference image
    dx = imgXpos[centerImg] - imgXpos
    dy = imgYpos[centerImg] - imgYpos

    # Honor subPixel=False in every mode: WCS-derived positions are
    # generally fractional and must be rounded here.
    if not subPixel:
        dx = np.round(dx)
        dy = np.round(dy)

    # Return the image offsets
    return (dx, dy)