def runDataRef(self, patchRefList):
    """Merge the catalogs read from a list of patch references and write them.

    Parameters
    ----------
    patchRefList : `list`
        Data references for one patch (presumably one per band). Each is
        passed to ``readCatalog``; the first one also supplies the sky info,
        the ID factory, the sky-source seed and the output location.
    """
    catalogs = dict(readCatalog(self, ref) for ref in patchRefList)
    primaryRef = patchRefList[0]
    skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=primaryRef)
    idFactory = self.makeIdFactory(primaryRef)
    seedDatasetName = self.config.coaddName + "MergedCoaddId"
    skySeed = primaryRef.get(seedDatasetName)
    result = self.run(catalogs, skyInfo, idFactory, skySeed)
    self.write(primaryRef, result.outputCatalog)
# Example #2
# 0
 def runDataRef(self, patchRefList):
     """Read, merge and write the catalogs of one patch.

     The first reference in ``patchRefList`` is the one used for sky info,
     the ID factory, the ``<coaddName>MergedCoaddId`` sky seed and output.
     """
     catalogs = dict(readCatalog(self, patchRef) for patchRef in patchRefList)
     leadRef = patchRefList[0]
     coaddName = self.config.coaddName
     skyInfo = getSkyInfo(coaddName=coaddName, patchRef=leadRef)
     idFactory = self.makeIdFactory(leadRef)
     skySeed = leadRef.get(coaddName + "MergedCoaddId")
     merged = self.run(catalogs, skyInfo, idFactory, skySeed)
     self.write(leadRef, merged.outputCatalog)
Exemple #3
0
    def mergeCatalogs(self, catalogs, patchRef):
        """Merge multiple catalogs.

        After ordering the catalogs and filters in priority order,
        ``self.merged.getMergedSourceCatalog`` performs the actual merging.
        Footprints from ``self.getSkySourceFootprints`` are then appended as
        blank-sky sources, the catalog is truncated to ``self.config.nObjects``
        records, peaks are sorted, and ``self.cullPeaks`` is used to remove
        garbage peaks detected around bright objects.

        Parameters
        ----------
        catalogs : `dict`
            Catalogs keyed by band. Only bands listed in
            ``self.config.priorityList`` are merged, in that order.
        patchRef : data reference
            Patch reference used for the sky info, the ID factory and the
            ``<coaddName>MergedCoaddId`` sky-source seed.

        Returns
        -------
        mergedList : catalog
            The merged, truncated and peak-culled source catalog.
        """
        # Convert the peak distances to tract pixels by dividing by the
        # tract pixel scale in arcseconds.
        skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=patchRef)
        tractWcs = skyInfo.wcs
        arcsecPerPixel = tractWcs.getPixelScale().asArcseconds()
        peakDistance = self.config.minNewPeak / arcsecPerPixel
        samePeakDistance = self.config.maxSamePeak / arcsecPerPixel

        # Put catalogs, filters in priority order
        presentBands = [band for band in self.config.priorityList if band in catalogs]
        orderedCatalogs = [catalogs[band] for band in presentBands]
        orderedBands = [getShortFilterName(band) for band in presentBands]

        mergedList = self.merged.getMergedSourceCatalog(orderedCatalogs, orderedBands, peakDistance,
                                                        self.schema, self.makeIdFactory(patchRef),
                                                        samePeakDistance)

        # Add extra sources that correspond to blank sky
        skySeed = patchRef.get(self.config.coaddName + "MergedCoaddId")
        skySourceFootprints = self.getSkySourceFootprints(mergedList, skyInfo, skySeed)
        if skySourceFootprints:
            key = mergedList.schema.find("merge_footprint_%s" % self.config.skyFilterName).key
            for foot in skySourceFootprints:
                record = mergedList.addNew()
                record.setFootprint(foot)
                record.set(key, True)

        # Keep only the first nObjects merged sources (debugging aid).
        # NOTE(review): the sky sources were appended above, at the end of the
        # catalog, so this truncation drops them whenever there are at least
        # nObjects real sources -- confirm that is intended.
        self.log.info("DEBUGGING: Keep {} sources".format(self.config.nObjects))
        mergedList = mergedList[:self.config.nObjects]

        # Sort Peaks from brightest to faintest
        for record in mergedList:
            record.getFootprint().sortPeaks()
        self.log.info("Merged to %d sources" % len(mergedList))
        # Attempt to remove garbage peaks
        self.cullPeaks(mergedList)
        return mergedList
# Example #4
# 0
 def runDataRef(self, patchRefList):
     """Merge the catalogs of one patch and write the merged result.

     Uses ``getGen3CoaddExposureId`` on the first reference (band not
     included) as the sky-source seed; the first reference is also used for
     the sky info, the ID factory and the output.
     """
     catalogs = dict(readCatalog(self, ref) for ref in patchRefList)
     firstRef = patchRefList[0]
     coaddName = self.config.coaddName
     skyInfo = getSkyInfo(coaddName=coaddName, patchRef=firstRef)
     idFactory = self.makeIdFactory(firstRef)
     skySeed = getGen3CoaddExposureId(firstRef, coaddName=coaddName,
                                      includeBand=False, log=self.log)
     mergeResult = self.run(catalogs, skyInfo, idFactory, skySeed)
     self.write(firstRef, mergeResult.outputCatalog)
    def run(self, patchRef):
        """Measure and deblend sources near fake-source footprints.

        Sources are filtered with the coadd's ``FAKE`` mask plane: a source is
        dropped when ``Footprint.intersectMask`` with the FAKE bit leaves its
        footprint area unchanged (per the original intent, this discards
        objects whose footprints do not overlap the FAKE mask). The remaining
        sources are optionally deblended, then measured, flagged, optionally
        matched, and written.

        Parameters
        ----------
        patchRef : data reference
            Patch reference used to read the coadd and sources, look up the
            sky info, and write the outputs.
        """
        coadd = patchRef.get(self.config.coaddName + "Coadd", immediate=True)

        # Read in the FAKE mask plane
        fakeMask = coadd.getMaskedImage().getMask()
        fakeBit = fakeMask.getPlaneBitMask('FAKE')

        sources = self.readSources(patchRef)
        self.log.info("Found %d sources" % len(sources))

        # Ignore objects whose footprints do NOT overlap with the 'FAKE' mask.
        # Indices are gathered first and deleted from the end so that earlier
        # deletions do not shift positions still waiting to be removed.
        dropIndices = []
        for index, record in enumerate(sources):
            footprint = record.getFootprint()
            clipped = afwDetect.Footprint(footprint)
            clipped.intersectMask(fakeMask, fakeBit)
            if clipped.getArea() == footprint.getArea():
                dropIndices.append(index)
        for index in reversed(dropIndices):
            del sources[index]
        self.log.info("Found %d sources near fake footprints" % len(sources))

        if self.config.doDeblend:
            self.deblend.run(coadd, sources, coadd.getPsf())

            bigKey = sources.schema["deblend.parent-too-big"].asKey()
            # The catalog is non-contiguous, so the flag column cannot be
            # extracted as an array; count record by record instead.
            numBig = 0
            for record in sources:
                numBig += record.get(bigKey)
            if numBig > 0:
                self.log.warn("Patch %s contains %d large footprints that were not deblended"
                              % (patchRef.dataId, numBig))

        self.measurement.run(coadd, sources)
        skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=patchRef)
        self.setPrimaryFlags.run(sources, skyInfo.skyMap, skyInfo.tractInfo, skyInfo.patchInfo,
                                 includeDeblend=self.config.doDeblend)
        butler = patchRef.getButler()
        ccdInputs = self.propagateFlags.getCcdInputs(coadd)
        self.propagateFlags.run(butler, sources, ccdInputs, coadd.getWcs())
        if self.config.doMatchSources:
            self.writeMatches(patchRef, coadd, sources)

        self.write(patchRef, sources)
    def run(self, patchRef):
        """Measure and deblend the sources of a patch that touch fake sources.

        A source is removed when intersecting a copy of its footprint with the
        coadd's ``FAKE`` mask bit (``Footprint.intersectMask``) leaves the area
        unchanged; the intent, per the original comment, is to ignore objects
        whose footprints do not overlap the FAKE mask. The survivors are
        optionally deblended, measured, flagged, optionally matched, and
        written via ``self.write``.

        Parameters
        ----------
        patchRef : data reference
            Reference for the patch to process.
        """
        exposure = patchRef.get(self.config.coaddName + "Coadd", immediate=True)

        # Read in the FAKE mask plane
        mask = exposure.getMaskedImage().getMask()
        fakebit = mask.getPlaneBitMask('FAKE')

        def unchangedByFakeMask(record):
            """True if intersecting with the FAKE bit keeps the whole footprint area."""
            original = record.getFootprint()
            trimmed = afwDetect.Footprint(original)
            trimmed.intersectMask(mask, fakebit)
            return trimmed.getArea() == original.getArea()

        sources = self.readSources(patchRef)
        self.log.info("Found %d sources"% len(sources))

        # Ignore objects whose footprints do NOT overlap with the 'FAKE' mask;
        # delete from the highest index down so remaining indices stay valid.
        doomed = [idx for idx, rec in enumerate(sources) if unchangedByFakeMask(rec)]
        for idx in sorted(doomed, reverse=True):
            del sources[idx]
        self.log.info("Found %d sources near fake footprints"% len(sources))

        if self.config.doDeblend:
            self.deblend.run(exposure, sources, exposure.getPsf())

            bigKey = sources.schema["deblend.parent-too-big"].asKey()
            # catalog is non-contiguous so can't extract column
            numBig = sum(rec.get(bigKey) for rec in sources)
            if numBig > 0:
                warning = "Patch %s contains %d large footprints that were not deblended"
                self.log.warn(warning % (patchRef.dataId, numBig))

        self.measurement.run(exposure, sources)
        skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=patchRef)
        self.setPrimaryFlags.run(
            sources,
            skyInfo.skyMap,
            skyInfo.tractInfo,
            skyInfo.patchInfo,
            includeDeblend=self.config.doDeblend,
        )
        self.propagateFlags.run(
            patchRef.getButler(),
            sources,
            self.propagateFlags.getCcdInputs(exposure),
            exposure.getWcs(),
        )
        if self.config.doMatchSources:
            self.writeMatches(patchRef, exposure, sources)

        self.write(patchRef, sources)
# Example #7
# 0
 def run(self, patchRef):
     """Measure and deblend.

     Reads the ``<coaddName>Coadd`` and its sources, sorts each footprint's
     peaks, optionally deblends, measures, sets primary flags, propagates
     flags from the CCD inputs, optionally writes matches, and writes the
     sources.

     Parameters
     ----------
     patchRef : data reference
         Reference for the patch being processed.
     """
     coaddName = self.config.coaddName
     exposure = patchRef.get(coaddName + "Coadd", immediate=True)
     sources = self.readSources(patchRef)
     # We sort Peaks by peak value at this stage, rather than immediately after merging them,
     # because in the future we may record different peak values in different bands, and hence
     # want to sort them differently in each band.  We'll have to modify this code at that point,
     # passing a Key for the appropriate field to sortPeaks().
     for source in sources:
         source.getFootprint().sortPeaks()
     if self.config.doDeblend:
         self.deblend.run(exposure, sources, exposure.getPsf())
     self.measurement.run(exposure, sources)
     skyInfo = getSkyInfo(coaddName=coaddName, patchRef=patchRef)
     self.setPrimaryFlags.run(sources, skyInfo.skyMap, skyInfo.tractInfo, skyInfo.patchInfo,
                              includeDeblend=self.config.doDeblend)
     butler = patchRef.getButler()
     ccdInputs = self.propagateFlags.getCcdInputs(exposure)
     self.propagateFlags.run(butler, sources, ccdInputs, exposure.getWcs())
     if self.config.doMatchSources:
         self.writeMatches(patchRef, exposure, sources)
     self.write(patchRef, sources)
# Example #8
# 0
    def mergeCatalogs(self, catalogs, patchRef):
        """Merge the per-band catalogs of a patch into a single catalog.

        Parameters
        ----------
        catalogs : `dict`
            Catalogs keyed by band; bands are merged in the order given by
            ``self.config.priorityList``, skipping bands that are absent.
        patchRef : data reference
            Patch reference used to look up sky info and build the ID factory.

        Returns
        -------
        mergedList : catalog
            Result of ``self.merged.getMergedSourceCatalog``, with slots copied
            from the first (highest-priority) input catalog.
        """
        skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=patchRef)
        tractWcs = skyInfo.wcs
        # Convert the peak distances to tract pixel coordinates.
        peakDistance = self.config.minNewPeak / tractWcs.pixelScale().asArcseconds()
        samePeakDistance = self.config.maxSamePeak / tractWcs.pixelScale().asArcseconds()

        # Bands present in `catalogs`, in priority order.
        presentBands = [band for band in self.config.priorityList if band in catalogs.keys()]
        orderedCatalogs = [catalogs[band] for band in presentBands]
        orderedBands = [getShortFilterName(band) for band in presentBands]

        idFactory = self.makeIdFactory(patchRef)
        mergedList = self.merged.getMergedSourceCatalog(
            orderedCatalogs, orderedBands, peakDistance, self.schema, idFactory, samePeakDistance)
        copySlots(orderedCatalogs[0], mergedList)
        self.log.info("Merged to %d sources" % len(mergedList))
        return mergedList
# Example #9
# 0
    def run(self, dataRef, outputFileName='ran.fits'):
        """Draw random points on a patch coadd and write them to a FITS table.

        ``self.config.N`` points are drawn with ``self.drawOnePoint`` from the
        coadd's mask array, dimensions, WCS, XY0 and the patch sky info. The
        ``ra``, ``dec``, ``value``, ``isPatchInner`` and ``isTractInner``
        results are written as the columns of a FITS binary table.

        Parameters
        ----------
        dataRef : data reference
            Patch reference used to read ``<coaddName>Coadd`` and its sky info.
        outputFileName : `str`, optional
            Path of the FITS file to write; an existing file is overwritten.
            Defaults to ``'ran.fits'`` (the previously hard-coded name). With
            the default every patch writes to the same file, so pass a
            per-patch name to keep the results of different patches.
        """
        # A 1-tuple keeps %-formatting correct whatever type dataId has.
        self.log.info("Processing %s" % (dataRef.dataId,))

        # initialize
        skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=dataRef)
        coadd = dataRef.get(self.config.coaddName + "Coadd")

        wcs = coadd.getWcs()
        xy0 = coadd.getXY0()

        mask = coadd.getMaskedImage().getMask()
        dim = mask.getDimensions()
        mask_array = mask.getArray()

        nPoints = self.config.N
        ra = [0.0] * nPoints
        dec = [0.0] * nPoints
        value = [0] * nPoints
        isPatchInner = [True] * nPoints
        isTractInner = [True] * nPoints

        for i in range(nPoints):
            (ra[i], dec[i], value[i],
             isPatchInner[i], isTractInner[i]) = self.drawOnePoint(mask_array, dim, wcs, xy0, skyInfo)

        columns = [
            pyfits.Column(name='ra', format='D', array=ra),
            pyfits.Column(name='dec', format='D', array=dec),
            pyfits.Column(name='value', format='J', array=value),
            pyfits.Column(name='isPatchInner', format='L', array=isPatchInner),
            pyfits.Column(name='isTractInner', format='L', array=isTractInner),
        ]
        tbhdu = pyfits.BinTableHDU.from_columns(columns)

        # ``clobber=True`` overwrites an existing file. Newer astropy.io.fits
        # spells this ``overwrite=True``; kept unchanged for compatibility with
        # whichever ``pyfits`` this module imports -- confirm before switching.
        tbhdu.writeto(outputFileName, clobber=True)
# Example #10
# 0
 def idListGenerator(self, cache, dataRefList, selectDataList=None):
     """Get a generator of difference images from the selectDataList.

     Parameters
     ----------
     cache :
         Not used by this method.
     dataRefList : `list`
         Data references; only the first is used, to look up the sky info and
         as the reference passed to ``self.selectExposures``.
     selectDataList : `list`, optional
         Forwarded to ``self.selectExposures``. ``None`` (the default) means
         an empty list; a fresh list is created per call instead of sharing a
         mutable default.

     Yields
     ------
     data : `Struct`
         With ``visit``, ``ccd`` (looked up with ``self.config.ccdKey``),
         ``filter``, ``exp`` (``<coaddName>Diff_differenceExp``) and ``src``
         (``<coaddName>Diff_diaSrc``). References whose datasets cannot be
         read are logged at debug level and skipped.
     """
     if selectDataList is None:
         selectDataList = []
     skyInfo = getSkyInfo(coaddName=self.config.coaddName,
                          patchRef=dataRefList[0])
     diffExpRefList = self.selectExposures(dataRefList[0],
                                           skyInfo,
                                           selectDataList=selectDataList)
     for dataRef in diffExpRefList:
         try:
             exp = dataRef.get(f"{self.config.coaddName}Diff_differenceExp")
             src = dataRef.get(f"{self.config.coaddName}Diff_diaSrc")
         except Exception as e:
             # Best effort: unreadable data is skipped. Use the configured CCD
             # key (as the Struct below does) so that logging cannot itself
             # raise KeyError when that key is not literally 'ccd'.
             self.log.debug(
                 'Cannot read data for %d %d. skipping %s' %
                 (dataRef.dataId['visit'], dataRef.dataId[self.config.ccdKey], e))
             continue
         # NOTE(review): runAssociation reads ``diffIm.calib`` from the items
         # this generator yields, but no ``calib`` is set here -- confirm.
         data = Struct(visit=dataRef.dataId['visit'],
                       ccd=dataRef.dataId[self.config.ccdKey],
                       filter=dataRef.dataId['filter'],
                       exp=exp,
                       src=src)
         yield data
# Example #11
# 0
    def run(self, dataRef, outputFileName='ran.fits'):
        """Draw random points on a patch coadd and write them to a FITS table.

        ``self.config.N`` points are drawn with ``self.drawOnePoint`` from the
        coadd's mask array, dimensions, WCS, XY0 and the patch sky info. The
        ``ra``, ``dec``, ``value``, ``isPatchInner`` and ``isTractInner``
        results are written as the columns of a FITS binary table.

        Parameters
        ----------
        dataRef : data reference
            Patch reference used to read ``<coaddName>Coadd`` and its sky info.
        outputFileName : `str`, optional
            Path of the FITS file to write; an existing file is overwritten.
            Defaults to ``'ran.fits'`` (the previously hard-coded name). With
            the default every patch writes to the same file, so pass a
            per-patch name to keep the results of different patches.
        """
        # A 1-tuple keeps %-formatting correct whatever type dataId has.
        self.log.info("Processing %s" % (dataRef.dataId,))

        # initialize
        skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=dataRef)
        coadd = dataRef.get(self.config.coaddName + "Coadd")

        wcs = coadd.getWcs()
        xy0 = coadd.getXY0()

        mask = coadd.getMaskedImage().getMask()
        dim = mask.getDimensions()
        mask_array = mask.getArray()

        nPoints = self.config.N
        ra = [0.0] * nPoints
        dec = [0.0] * nPoints
        value = [0] * nPoints
        isPatchInner = [True] * nPoints
        isTractInner = [True] * nPoints

        for i in range(nPoints):
            (ra[i], dec[i], value[i],
             isPatchInner[i], isTractInner[i]) = self.drawOnePoint(mask_array, dim, wcs, xy0, skyInfo)

        columns = [
            pyfits.Column(name='ra', format='D', array=ra),
            pyfits.Column(name='dec', format='D', array=dec),
            pyfits.Column(name='value', format='J', array=value),
            pyfits.Column(name='isPatchInner', format='L', array=isPatchInner),
            pyfits.Column(name='isTractInner', format='L', array=isTractInner),
        ]
        tbhdu = pyfits.BinTableHDU.from_columns(columns)

        # ``clobber=True`` overwrites an existing file. Newer astropy.io.fits
        # spells this ``overwrite=True``; kept unchanged for compatibility with
        # whichever ``pyfits`` this module imports -- confirm before switching.
        tbhdu.writeto(outputFileName, clobber=True)
# Example #12
# 0
    def mergeCatalogs(self, catalogs, patchRef):
        """Merge multiple catalogs.

        After ordering the catalogs and filters in priority order,
        ``self.merged.getMergedSourceCatalog`` performs the actual merging.
        Footprints from ``self.getSkySourceFootprints`` are then appended as
        blank-sky sources, the catalog is truncated to ``self.config.nObjects``
        records, peaks are sorted, and ``self.cullPeaks`` is used to remove
        garbage peaks detected around bright objects.

        Parameters
        ----------
        catalogs : `dict`
            Catalogs keyed by band. Only bands listed in
            ``self.config.priorityList`` are merged, in that order.
        patchRef : data reference
            Patch reference used for the sky info, the ID factory and the
            ``<coaddName>MergedCoaddId`` sky-source seed.

        Returns
        -------
        mergedList : catalog
            The merged, truncated and peak-culled source catalog.
        """
        # Convert the peak distances to tract pixels by dividing by the
        # tract pixel scale in arcseconds.
        skyInfo = getSkyInfo(coaddName=self.config.coaddName,
                             patchRef=patchRef)
        tractWcs = skyInfo.wcs
        arcsecPerPixel = tractWcs.getPixelScale().asArcseconds()
        peakDistance = self.config.minNewPeak / arcsecPerPixel
        samePeakDistance = self.config.maxSamePeak / arcsecPerPixel

        # Put catalogs, filters in priority order
        presentBands = [
            band for band in self.config.priorityList if band in catalogs
        ]
        orderedCatalogs = [catalogs[band] for band in presentBands]
        orderedBands = [getShortFilterName(band) for band in presentBands]

        mergedList = self.merged.getMergedSourceCatalog(
            orderedCatalogs, orderedBands, peakDistance, self.schema,
            self.makeIdFactory(patchRef), samePeakDistance)

        # Add extra sources that correspond to blank sky
        skySeed = patchRef.get(self.config.coaddName + "MergedCoaddId")
        skySourceFootprints = self.getSkySourceFootprints(
            mergedList, skyInfo, skySeed)
        if skySourceFootprints:
            key = mergedList.schema.find("merge_footprint_%s" %
                                         self.config.skyFilterName).key
            for foot in skySourceFootprints:
                record = mergedList.addNew()
                record.setFootprint(foot)
                record.set(key, True)

        # Keep only the first nObjects merged sources (debugging aid).
        # NOTE(review): the sky sources were appended above, at the end of the
        # catalog, so this truncation drops them whenever there are at least
        # nObjects real sources -- confirm that is intended.
        self.log.info("DEBUGGING: Keep {} sources".format(
            self.config.nObjects))
        mergedList = mergedList[:self.config.nObjects]

        # Sort Peaks from brightest to faintest
        for record in mergedList:
            record.getFootprint().sortPeaks()
        self.log.info("Merged to %d sources" % len(mergedList))
        # Attempt to remove garbage peaks
        self.cullPeaks(mergedList)
        return mergedList
# Example #13
# 0
    def runAssociation(self, cache, dataIdList, selectDataList):
        """Run association on a patch.

        For all of the visits that overlap this patch in the band, create a
        DIAObject catalog. Only the objects in the non-overlapping area of the
        tract and patch are included.

        Parameters
        ----------
        cache :
            Object exposing ``butler`` (used to build data references); it is
            also forwarded to the difference-image generator.
        dataIdList : `list`
            Data IDs for the patch. The first one supplies the patch WCS and
            sky info and is where the outputs are written.
        selectDataList : `list`
            When empty, difference images come from ``self.catalogGenerator``;
            otherwise from ``self.idListGenerator``.

        Notes
        -----
        Returns `None`. Results are written with the butler as
        ``<coaddName>Diff_diaObject`` and ``<coaddName>Diff_diaObjectId``.
        If the coadd calexp cannot be read, the failure is logged and the
        method returns without writing anything.
        """
        dataRefList = [
            getDataRef(cache.butler, dataId,
                       self.config.coaddName + "Coadd_calexp")
            for dataId in dataIdList
        ]

        # We need the WCS for the patch, so we can use the first entry in the dataIdList
        dataRef = dataRefList[0]
        tract = dataRef.dataId['tract']
        skyInfo = getSkyInfo(coaddName=self.config.coaddName, patchRef=dataRef)
        skyMap = skyInfo.skyMap
        try:
            calexp = dataRef.get(f"{self.config.coaddName}Coadd_calexp")
        except Exception:
            self.log.info('Cannot read coadd data for %s' % (dataRef.dataId))
            return

        coaddWcs = calexp.getWcs()
        innerPatchBox = geom.Box2D(skyInfo.patchInfo.getInnerBBox())
        # Footprints are transformed into the coadd's parent bounding box,
        # which is the same for every difference image.
        region = calexp.getBBox(afwImage.PARENT)

        # NOTE(review): these dataset names are hard-coded to the "deep" coadd
        # rather than built from self.config.coaddName -- confirm intended.
        expBits = dataRef.get("deepMergedCoaddId_bits")
        expId = int(dataRef.get("deepMergedCoaddId"))
        idFactory = afwTable.IdFactory.makeSource(expId, 64 - expBits)

        if len(selectDataList) == 0:
            differenceImages = self.catalogGenerator
        else:
            differenceImages = self.idListGenerator

        initializeSelector = False
        for diffIm in differenceImages(cache, dataRefList, selectDataList):

            # The associator is initialized once, with the schema of the first
            # difference-image source catalog (even if that catalog is empty).
            if not initializeSelector:
                self.associator.initialize(diffIm.src.schema, idFactory)
                initializeSelector = True

            if len(diffIm.src) == 0:
                continue

            srcWcs = diffIm.exp.getWcs()
            # Sky position of every source centroid, shared by the patch and
            # tract membership tests below.
            skyPositions = [srcWcs.pixelToSky(rec.getCentroid()) for rec in diffIm.src]

            isInside = np.array([
                innerPatchBox.contains(coaddWcs.skyToPixel(sky))
                for sky in skyPositions
            ], dtype=bool)

            isGood = np.array([
                rec.getFootprint().contains(geom.Point2I(rec.getCentroid()))
                for rec in diffIm.src
            ])

            isInnerTract = np.array([
                skyMap.findTract(sky).getId() == tract for sky in skyPositions
            ])

            mask = isInside & isGood & isInnerTract

            src = diffIm.src[mask]
            if len(src) == 0:
                continue

            self.log.info(
                'Reading difference image %d %d, %s with %d possible sources' %
                (diffIm.visit, diffIm.ccd, diffIm.filter, len(src)))

            footprints = []
            for rec in src:
                # transformations on large footprints can take a long time
                # We truncate the footprint since we will rarely be interested
                # in such large footprints
                if rec.getFootprint().getArea() > self.config.maxFootprintArea:
                    spans = afwGeom.SpanSet.fromShape(
                        self.config.defaultFootprintRadius,
                        afwGeom.Stencil.CIRCLE,
                        geom.Point2I(rec.getCentroid()))
                    foot = afwDet.Footprint(spans)
                    foot.addPeak(int(rec.getX()), int(rec.getY()), 1)
                else:
                    foot = rec.getFootprint()
                footprints.append(foot.transform(srcWcs, coaddWcs, region))

            # NOTE(review): requires the generator's items to carry ``calib``;
            # the Struct yielded by idListGenerator in this file has no such
            # field -- confirm.
            self.associator.addCatalog(src, diffIm.filter, diffIm.visit,
                                       diffIm.ccd, diffIm.calib, footprints)

        result = self.associator.finalize(idFactory)

        if len(dataRefList) > 0 and result is not None:
            dataRefList[0].put(result,
                               self.config.coaddName + 'Diff_diaObject')
            self.log.info('Total objects found %d' % len(result))

            idCatalog = self.associator.getObjectIds()
            dataRefList[0].put(idCatalog,
                               self.config.coaddName + 'Diff_diaObjectId')
    def load_model(self, model_repository=None, filter_name='g', doWarp=False):
        """Depersist a DCR model from a repository and set up the metadata.

        Every ``dcrCoadd`` yielded by ``self.read_exposures`` is read (one per
        subfilter, presumably -- confirm), its image array is resized to
        ``self.bbox`` and appended to ``self.model``; weights, mask, WCS,
        bounding box, bandpasses and PSF are then derived.

        Parameters
        ----------
        model_repository : str or None, optional
            Full path to the directory of the repository to load the ``dcrCoadd`` from.
            If not set, uses the existing self.butler
        filter_name : str, optional
            Common name of the filter used. For LSST, use u, g, r, i, z, or y
        doWarp : bool, optional
            Set if the input coadds need to be warped to the reference wcs.

        Returns
        -------
        None
            Loads ``self.model`` and sets up the quantities needed later, such
            as the psf and bandpass objects.

        Notes
        -----
        - Weights, mask, WCS, exposure time, observatory and PSF all come from
          the LAST ``dcrCoadd`` read; the code assumes these are identical for
          every subfilter.
        - NOTE(review): if ``self.read_exposures`` yields nothing, ``var_use``
          is never bound and the weights computation raises
          ``UnboundLocalError``.
        - The resize in the loop uses ``self.bbox`` (and the optional warp uses
          ``self.wcs``) as they are on entry; both are replaced only after the
          loop.
        """
        self.filter_name = filter_name
        model_arr = []
        dcrCoadd_gen = self.read_exposures(datasetType="dcrCoadd", input_repository=model_repository)
        for dcrCoadd in dcrCoadd_gen:
            if doWarp:
                # NOTE(review): the return value is discarded, so this only has
                # an effect if wrap_warpExposure modifies dcrCoadd in place -- confirm.
                wrap_warpExposure(dcrCoadd, self.wcs, self.bbox)
            model_in = dcrCoadd.getMaskedImage().getImage().getArray()
            var_in = dcrCoadd.getMaskedImage().getVariance().getArray()
            mask_in = dcrCoadd.getMaskedImage().getMask().getArray()
            # Only the resized image is kept per subfilter; var_use and mask_use
            # are overwritten on each iteration and used after the loop.
            model_use, var_use, mask_use = _resize_image(model_in, var_in, mask_in, bbox_new=self.bbox,
                                                         bbox_old=dcrCoadd.getBBox(), expand=False)
            model_arr.append(model_use)

        self.model = model_arr
        # Number of model planes read; recomputed (same value) further below.
        self.n_step = len(self.model)
        # The weights should be identical for all subfilters.
        # Inverse variance where the variance is positive, zero elsewhere.
        self.weights = np.zeros_like(var_use)
        nonzero_inds = var_use > 0
        self.weights[nonzero_inds] = 1./var_use[nonzero_inds]
        # self.weights = var_use*self.n_step
        # The masks should be identical for all subfilters
        self.mask = mask_use

        # Sky info is looked up from the subfilter-0 dcrCoadd reference.
        skyInfo = coaddBase.getSkyInfo("dcr", self.makeDataRef("dcrCoadd", subfilter=0))
        self.skyMap = skyInfo.skyMap

        self.wcs = dcrCoadd.getWcs()
        # From here on self.bbox is the patch inner bounding box.
        self.bbox = skyInfo.patchInfo.getInnerBBox()
        x_size, y_size = self.bbox.getDimensions()
        self.n_step = len(self.model)
        self.x_size = x_size
        self.y_size = y_size
        self.pixel_scale = self.wcs.getPixelScale()
        self.exposure_time = dcrCoadd.getInfo().getVisitInfo().getExposureTime()
        self.observatory = dcrCoadd.getInfo().getVisitInfo().getObservatory()
        # A first bandpass sets the wavelength range, which is split into
        # n_step equal steps for the per-subfilter bandpass.
        bandpass_init = self.load_bandpass(filter_name=filter_name)
        wavelength_step = (bandpass_init.wavelen_max - bandpass_init.wavelen_min) / self.n_step
        self.bandpass = self.load_bandpass(filter_name=filter_name, wavelength_step=wavelength_step)
        self.bandpass_highres = self.load_bandpass(filter_name=filter_name, wavelength_step=None)

        self.psf = dcrCoadd.getPsf()
        # Size of the PSF kernel image along its first array axis.
        psf_avg = self.psf.computeKernelImage().getArray()
        self.psf_size = psf_avg.shape[0]
        self.debug = False