Ejemplo n.º 1
0
    def _combine_exclude_mask(self, mask):
        """Combine ``mask`` with the user-supplied source-finding mask.

        The user mask is taken from ``self.src_find_filters['region_file']``,
        which may name either a FITS file holding a mask image or a region
        file readable by ``pyregion``.  The resulting boolean array is
        combined with ``mask`` using a logical AND.

        Parameters
        ----------
        mask : numpy.ndarray or None
            Input DQ mask.  When it is `None` the user mask alone is returned.

        Returns
        -------
        numpy.ndarray or None
            ``mask`` unchanged when ``self.src_find_filters`` is `None` or has
            no ``'region_file'`` entry; otherwise the combined boolean mask
            with the same shape as ``self.source``.

        Raises
        ------
        IOError
            If the file named by ``'region_file'`` does not exist.
        ValueError
            If a FITS mask file was supplied but none of its extensions has
            the same shape as ``self.source``.

        Notes
        -----
        As a debugging aid the combined mask is also written to
        ``<root of self.fname>_srcfind_mask.fits`` (overwriting any existing
        file).
        """
        filters = self.src_find_filters
        if filters is None or 'region_file' not in filters:
            # nothing to combine with: hand the input mask back untouched
            return mask

        reg_file_name = filters['region_file']
        if not os.path.isfile(reg_file_name):
            raise IOError(
                "The 'exclude' region file '{:s}' does not exist.".format(
                    reg_file_name))

        # get data image size:
        (img_ny, img_nx) = self.source.shape
        regmask = None

        # find out if user provided a region file or a mask FITS file:
        reg_file_ext = os.path.splitext(reg_file_name)[-1]
        if reg_file_ext.lower().strip() in ['.fits', '.fit'] and \
           basicFITScheck(reg_file_name):
            # Likely a FITS file: use the first extension whose data array
            # has the same shape as the science image.
            hdulist = fits.open(reg_file_name, memmap=False)
            try:
                for ext in get_extver_list(hdulist, extname=None):
                    usermask = hdulist[ext].data
                    # extensions without a data array cannot be a mask
                    if usermask is not None and \
                       usermask.shape == (img_ny, img_nx):
                        # builtin ``bool``: the ``np.bool`` alias is not
                        # available in every NumPy release
                        regmask = usermask.astype(bool)
                        break
            finally:
                # release the file even if reading an extension failed
                hdulist.close()

            if regmask is None:
                raise ValueError("None of the image-like extensions in the "
                                 "user-provided exclusion mask '{}' has a "
                                 "correct shape".format(reg_file_name))

        else:
            # we are dealing with a region file:
            reglist = pyregion.open(reg_file_name)

            #TODO: if region files in sky coordinates ever become
            #      unsupported, filter the list down to regions whose
            #      'coord_format' is 'image' or 'physical' instead of
            #      converting them below.
            # Convert regions from sky coordinates to image coordinates:
            auxwcs = _AuxSTWCS(self.wcs)
            reglist = reglist.as_imagecoord(auxwcs, rot_wrt_axis=2)

            # if all regions are exclude regions, then assume that the entire
            # image should be included and that exclude regions exclude from
            # this rectangular region representing the entire image:
            if all(x.exclude for x in reglist):
                # we slightly widen the box to make sure that
                # the entire image is covered:
                imreg = pyregion.parse(
                    "image;box({:.1f},{:.1f},{:d},{:d},0)".format(
                        (img_nx + 1) / 2.0, (img_ny + 1) / 2.0, img_nx + 1,
                        img_ny + 1))

                reglist = pyregion.ShapeList(imreg + reglist)

            # create a mask from regions:
            regmask = np.asarray(reglist.get_mask(shape=(img_ny, img_nx)),
                                 dtype=bool)

        if mask is not None and regmask is not None:
            mask = np.logical_and(regmask, mask)
        else:
            mask = regmask

        #DEBUG: dump the combined mask next to the input image
        if mask is not None:
            fn = os.path.splitext(self.fname)[0] + '_srcfind_mask.fits'
            fits.writeto(fn, mask.astype(dtype=np.uint8), overwrite=True)

        return mask
Ejemplo n.º 2
0
def map_region_files(input_reg,
                     images,
                     img_wcs_ext='sci',
                     refimg=None,
                     ref_wcs_ext='sci',
                     chip_reg=None,
                     outpath='./regions',
                     filter=None,
                     catfname=None,
                     iteractive=False,
                     append=False,
                     verbose=True):
    """Map input regions onto every requested image extension.

    The input regions are gathered into a single list of sky-coordinate
    regions, converted to image coordinates for each image extension and
    written to one region file per extension inside ``outpath``.

    Parameters
    ----------
    input_reg, refimg, ref_wcs_ext
        Forwarded to ``build_reg_refwcs_header_list`` to obtain the list of
        region / reference-WCS-header pairs.
    images, chip_reg, img_wcs_ext
        Forwarded to ``build_img_ext_reg_list`` to obtain the image file
        names and the list of "chip" region / extension pairs.
    outpath : str or None
        Existing directory that receives the region files.  `None` or ``""``
        selects the current directory.
    filter : {None, 'fast', 'precise'}
        Case-insensitive.  ``'fast'`` removes regions lying outside the image
        via ``fast_filter_outer_regions``.  ``'precise'`` is not implemented
        yet: a warning is printed and ``'fast'`` is used instead.
    catfname : str or None
        When given, a catalog file with this name is written to ``outpath``.
        Each line holds an image file name followed by the names of its
        region files (``'None'`` for an extension without regions).
    iteractive : bool
        Accepted for interface compatibility; not used by this function.
    append : bool
        When `True` and the output region file already exists, the existing
        regions are kept and the new ones are appended to them.
    verbose : bool
        Forwarded to the two ``build_*`` helpers.

    Raises
    ------
    IOError
        If ``outpath`` is not an existing directory.
    TypeError
        If ``filter`` is not `None`, ``'fast'`` or ``'precise'``.
    RuntimeError
        If an image extension is of an unsupported type.
    """
    # Check that output directory exists:
    if outpath in [None, ""]:
        outpath = os.path.curdir + os.path.sep
    elif not os.path.isdir(outpath):
        raise IOError("The output directory \'%s\' does not exist." % outpath)

    # normalize the filter name once instead of lower-casing it on every use
    filter_mode = None if filter is None else filter.lower()
    if filter_mode is not None and filter_mode not in ["fast", "precise"]:
        raise TypeError("The 'filter' argument can be None, 'fast', "
                        "or 'precise'.")

    #TODO: Implement the "precise" checking of region intersection
    # with the image's bounding box.
    if filter_mode == "precise":
        _print_warning("\"precise\" filter option is not yet supported. "
                       "Reverting to filter = \"fast\" instead.")
        filter_mode = "fast"

    # create a 2D list of region - header (with WCS) pairs:
    # [[reg1,ref_wcs_header1],[reg2,ref_wcs_header2],...]
    regheaders = build_reg_refwcs_header_list(input_reg, refimg, ref_wcs_ext,
                                              verbose)

    # create a list of input files *without* extensions (if present) and
    # a 2D list of "chip" region - image extension pairs:
    # [[reg1, ext1], [reg2,ext2],...]
    imgfnames, cregext = build_img_ext_reg_list(images, chip_reg, img_wcs_ext,
                                                verbose)

    # build a "master" list of regions in sky coordinates:
    all_sky_regions = []
    for p in regheaders:
        if p[1] is None:  # Already in sky WCS
            all_sky_regions += p[0]
        else:
            all_sky_regions += shapes_img2sky(
                p[0], p[1])  #TODO: implement shapes_img2sky

    all_sky_regions = pyregion.ShapeList(list(all_sky_regions))

    cattb = []
    # create a region file (with regions in image coordinates) for each
    # image extension from the input_reg and chip_reg regions
    for fname in imgfnames:
        catreg = []
        imghdu = fits.open(fname, memmap=False)
        try:
            for extp in cregext:
                ext = extp[1]

                # check that the extension is of supported type:
                if not _is_supported_hdu(imghdu[ext].header):
                    raise RuntimeError("Extension {} is of unsupported "
                                       "type.".format(ext))

                # A WCS object (rather than the raw header) is handed to
                # pyregion for the sky -> image conversion.
                wcs = _AuxSTWCS(imghdu, ext=ext)
                extreg = all_sky_regions.as_imagecoord(wcs, rot_wrt_axis=2)

                if extp[0]:  # add "fixed" regions if any
                    extreg = pyregion.ShapeList(extreg + extp[0])

                # filter out regions outside the image
                # (the return value is not used, so the helper presumably
                # prunes ``extreg`` in place -- TODO confirm):
                if filter_mode == 'fast':
                    fast_filter_outer_regions(extreg,
                                              imghdu[ext].header['NAXIS1'],
                                              imghdu[ext].header['NAXIS2'],
                                              origin=1)

                if len(extreg) < 1:  # do not create empty region files
                    catreg.append('None')
                    continue

                # generate output region file name:
                extsuffix = _ext2str_suffix(ext)
                basefname, fext = os.path.splitext(os.path.basename(fname))
                regfname = basefname + extsuffix + os.extsep + "reg"
                fullregfname = os.path.join(outpath, regfname)

                catreg.append(regfname)

                # save regions to a file:
                if append and os.path.isfile(fullregfname):
                    old_extreg = pyregion.open(fullregfname)
                    extreg = pyregion.ShapeList(old_extreg + extreg)

                #TODO: switch to extreg.write(fullregfname) once the
                # pyregion.write bugs are fixed in a public pyregion release;
                # until then use the local _regwrite implementation.
                _regwrite(extreg, fullregfname)
        finally:
            # always release the FITS file, on success as well as on error
            imghdu.close()

        cattb.append([fname, catreg])

    # create exclusions catalog file:
    if catfname:
        with open(os.path.join(outpath, catfname), 'w') as catfh:
            for catentry in cattb:
                catfh.write(catentry[0])  # image file name
                for reg in catentry[1]:
                    catfh.write(' ' + reg)  # region file name
                catfh.write('\n')
# Name of the region file to load.
reg_name = "test.reg"
# Open the region file and convert its shapes to image coordinates using the
# header of the first element of ``f_xray`` (``f_xray`` is defined outside
# this fragment; presumably an opened FITS HDU list -- confirm).
r = stregion.open(reg_name).as_imagecoord(header=f_xray[0].header)

from stregion.mpl_helper import properties_func_default


# Use custom function for patch attribute
def fixed_color(shape, saved_attrs):
    """Build properties for ``shape`` with its color forced to red.

    ``saved_attrs`` is an ``(attr_list, attr_dict)`` pair.  The ``"color"``
    entry of the dictionary is overwritten in place and the pair is then
    handed to ``properties_func_default``, whose result is returned.
    """
    extra_list, extra_dict = saved_attrs
    extra_dict["color"] = "red"  # in-place: the caller's dict is modified
    return properties_func_default(shape, (extra_list, extra_dict))


# Shapes tagged "Group 1" get the custom ``fixed_color`` properties function.
r1 = stregion.ShapeList(
    [shape for shape in r if shape.attr[1].get("tag") == "Group 1"]
)
patch_list1, artist_list1 = r1.get_mpl_patches_texts(fixed_color)

# Every remaining shape is converted with the default properties.
r2 = stregion.ShapeList(
    [shape for shape in r if shape.attr[1].get("tag") != "Group 1"]
)
patch_list2, artist_list2 = r2.get_mpl_patches_texts()

# Attach the patches first, then the text artists, and display the figure.
for p in patch_list1 + patch_list2:
    ax.add_patch(p)

for t in artist_list1 + artist_list2:
    ax.add_artist(t)

plt.show()