def numpyToStack(images, center, offset):
    """Describe a numpy image cube with LSST stack objects.

    A rectangular `Footprint` covering the whole extent of ``images`` is
    built, and a single peak is placed at ``center`` after shifting it by
    ``offset`` into parent coordinates.

    Parameters
    ----------
    images : `numpy.ndarray`
        Image cube indexed as ``(band, y, x)``.
    center : `tuple` of `int`
        ``(y, x)`` position of the peak, relative to the cube origin.
    offset : `tuple` of `int`
        ``(x, y)`` parent coordinates of the cube's lower-left pixel.

    Returns
    -------
    foot : `Footprint`
        Footprint spanning the full bounding box, carrying one peak.
    peak : peak record
        The peak that was added to ``foot``.
    bbox : `Box2I`
        Bounding box of the cube in parent coordinates.
    """
    yPeak, xPeak = center
    _, nRows, nCols = images.shape
    xOrigin, yOrigin = offset

    bbox = Box2I(Point2I(xOrigin, yOrigin), Extent2I(nCols, nRows))
    foot = Footprint(SpanSet(bbox))

    # Use the brightest band at the center pixel as the peak value.
    peakValue = images[:, yPeak, xPeak].max()
    foot.addPeak(xPeak + xOrigin, yPeak + yOrigin, peakValue)

    return foot, foot.getPeaks()[0], bbox
Esempio n. 2
0
 def setUp(self):
     """Build a three-band `MultibandFootprint` fixture.

     Band number ``n`` is a constant image of value ``n`` projected onto a
     radius-2 circular footprint that carries two peaks.
     """
     np.random.seed(1)
     self.spans = SpanSet.fromShape(2, Stencil.CIRCLE)
     self.footprint = Footprint(self.spans)
     for xPk, yPk, value in ((3, 4, 10), (8, 1, 2)):
         self.footprint.addPeak(xPk, yPk, value)

     # Rebuild the same peaks on a separate footprint to act as the
     # reference peak catalog.
     reference = Footprint(self.spans)
     for record in self.footprint.getPeaks():
         reference.addPeak(record["f_x"], record["f_y"], record["peakValue"])
     self.peaks = reference.getPeaks()
     self.bbox = self.footprint.getBBox()
     self.filters = ("G", "R", "I")

     arrays = []
     heavies = []
     for level, _filterName in enumerate(self.filters):
         bandImage = ImageF(self.spans.getBBox())
         bandImage.set(level)
         arrays.append(bandImage.array)
         heavies.append(makeHeavyFootprint(self.footprint, MaskedImageF(bandImage)))
     self.image = np.array(arrays)
     self.mFoot = MultibandFootprint(self.filters, heavies)
Esempio n. 3
0
 def setUp(self):
     """Build a three-band `MultibandFootprint` fixture.

     Band number ``n`` is a constant image of value ``n`` projected onto a
     radius-2 circular footprint that carries two peaks.
     """
     np.random.seed(1)
     self.spans = SpanSet.fromShape(2, Stencil.CIRCLE)
     self.footprint = Footprint(self.spans)
     for xPk, yPk, value in ((3, 4, 10), (8, 1, 2)):
         self.footprint.addPeak(xPk, yPk, value)

     # Rebuild the same peaks on a separate footprint to act as the
     # reference peak catalog.
     reference = Footprint(self.spans)
     for record in self.footprint.getPeaks():
         reference.addPeak(record["f_x"], record["f_y"], record["peakValue"])
     self.peaks = reference.getPeaks()
     self.bbox = self.footprint.getBBox()
     self.filters = ("G", "R", "I")

     arrays = []
     heavies = []
     for level, _filterName in enumerate(self.filters):
         bandImage = ImageF(self.spans.getBBox())
         bandImage.set(level)
         arrays.append(bandImage.array)
         heavies.append(makeHeavyFootprint(self.footprint, MaskedImageF(bandImage)))
     self.image = np.array(arrays)
     self.mFoot = MultibandFootprint(self.filters, heavies)
    def test_deblend_task(self):
        """Run `ScarletDeblendTask` on a synthetic five-band blend.

        The catalog holds the sources detected in the ``r`` band plus two
        synthetic parents. One has a footprint larger than
        ``config.maxFootprintArea``. The other has more than
        ``config.maxNumberOfPeaks`` peaks.

        The test checks that:

        - band-independent columns agree across bands;
        - parent/child bookkeeping and column inheritance are consistent;
        - child heavy footprints match their catalog peak flux and position;
        - only the two synthetic parents are flagged and skipped.
        """
        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25),
            (10, 30),
            (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80,
            60,
            90,
            # isolated source
            20,
        ]
        result = initData(shape, coords, amplitudes)
        # Only ``images`` and ``psfs`` are used below; the other outputs of
        # initData are unpacked to document its return order.
        targetPsfImage, psfImages, images, channels, seds, morphs, targetPsf, psfs = result

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10 * (np.random.rand(*images.shape).astype(np.float32) - .5)
        images += noise

        filters = "grizy"
        # NOTE(review): ``noise`` (uniform in [-5, 5), so partly negative) is
        # passed as the variance plane. Confirm whether ``noise**2`` or a
        # constant variance was intended.
        _images = afwImage.MultibandMaskedImage.fromArrays(
            filters, images.astype(np.float32), None, noise)
        coadds = [
            afwImage.Exposure(img, dtype=img.image.array.dtype)
            for img in _images
        ]
        coadds = afwImage.MultibandExposure.fromExposures(filters, coadds)
        for b, coadd in enumerate(coadds):
            coadd.setPsf(psfs[b])

        schema = SourceCatalog.Table.makeMinimalSchema()

        detectionTask = SourceDetectionTask(schema=schema)

        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        deblendTask = ScarletDeblendTask(schema=schema, config=config)

        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large: a box of half-width
        # ceil(sqrt(maxFootprintArea) + 1) has an area well above the limit.
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks + 1):
            denseFoot.addPeak(70 + 2 * n, 15 + 2 * n, 10 * n)
        src.setFootprint(denseFoot)

        # Run the deblender
        result = deblendTask.run(coadds, catalog)

        # Make sure that the catalogs have the same sources in all bands,
        # and check that band-independent columns are equal
        bandIndependentColumns = [
            "id",
            "parent",
            "deblend_nPeaks",
            "deblend_nChild",
            "deblend_peak_center_x",
            "deblend_peak_center_y",
            "deblend_runtime",
            "deblend_iterations",
            "deblend_logL",
            "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag",
        ]
        self.assertEqual(len(filters), len(result))
        ref = result[filters[0]]
        for f in filters[1:]:
            for col in bandIndependentColumns:
                np.testing.assert_array_equal(result[f][col], ref[col])

        # Check that other columns are consistent
        for f, _catalog in result.items():
            parents = _catalog[_catalog["parent"] == 0]
            # Check that the number of deblended children is consistent
            self.assertEqual(np.sum(_catalog["deblend_nChild"]),
                             len(_catalog) - len(parents))

            for parent in parents:
                children = _catalog[_catalog["parent"] == parent.get("id")]
                # Check that nChild is set correctly
                self.assertEqual(len(children), parent.get("deblend_nChild"))
                # Check that parent columns are propagated to their children
                for parentCol, childCol in config.columnInheritance.items():
                    np.testing.assert_array_equal(parent.get(parentCol),
                                                  children[childCol])

            children = _catalog[_catalog["parent"] != 0]
            for child in children:
                fp = child.getFootprint()
                img = heavyFootprintToImage(fp)
                # Check that the flux at the center is correct.
                # Note: this only works in this test image because the
                # detected peak is in the same location as the scarlet peak.
                # If the peak is shifted, the flux value will be correct
                # but deblend_peak_center is not the correct location.
                px = child.get("deblend_peak_center_x")
                py = child.get("deblend_peak_center_y")
                flux = img.image[Point2I(px, py)]
                self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                # Check that the peak positions match the catalog entry
                peaks = fp.getPeaks()
                self.assertEqual(px, peaks[0].getIx())
                self.assertEqual(py, peaks[0].getIy())

            # Check that all sources have the correct number of peaks
            for src in _catalog:
                fp = src.getFootprint()
                self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

            # The synthetic parents were appended after the detected
            # sources. The indices below assume detection found exactly two
            # parents (the blend and the isolated source), so the large and
            # dense footprints are rows 2 and 3.

            # Check that only the large footprint was flagged as too big
            largeFootprint = np.zeros(len(_catalog), dtype=bool)
            largeFootprint[2] = True
            np.testing.assert_array_equal(largeFootprint,
                                          _catalog["deblend_parentTooBig"])

            # Check that only the dense footprint was flagged as too dense
            denseFootprint = np.zeros(len(_catalog), dtype=bool)
            denseFootprint[3] = True
            np.testing.assert_array_equal(denseFootprint,
                                          _catalog["deblend_tooManyPeaks"])

            # Check that only the appropriate parents were skipped
            skipped = largeFootprint | denseFootprint
            np.testing.assert_array_equal(skipped, _catalog["deblend_skipped"])
Esempio n. 5
0
class MultibandFootprintTestCase(lsst.utils.tests.TestCase):
    """Tests for `MultibandFootprint`.

    Covers construction from arrays and images, filter slicing, spans and
    peaks.
    """
    def setUp(self):
        """Build a three-band fixture from constant images of value 0, 1, 2."""
        np.random.seed(1)
        self.spans = SpanSet.fromShape(2, Stencil.CIRCLE)
        self.footprint = Footprint(self.spans)
        self.footprint.addPeak(3, 4, 10)
        self.footprint.addPeak(8, 1, 2)
        # Reference peak catalog, rebuilt on a separate footprint.
        fp = Footprint(self.spans)
        for peak in self.footprint.getPeaks():
            fp.addPeak(peak["f_x"], peak["f_y"], peak["peakValue"])
        self.peaks = fp.getPeaks()
        self.bbox = self.footprint.getBBox()
        self.filters = ("G", "R", "I")
        singles = []
        images = []
        for n, f in enumerate(self.filters):
            image = ImageF(self.spans.getBBox())
            image.set(n)
            images.append(image.array)
            maskedImage = MaskedImageF(image)
            heavy = makeHeavyFootprint(self.footprint, maskedImage)
            singles.append(heavy)
        self.image = np.array(images)
        self.mFoot = MultibandFootprint(self.filters, singles)

    def tearDown(self):
        del self.spans
        del self.footprint
        del self.peaks
        del self.bbox
        del self.filters
        del self.mFoot
        del self.image

    def verifyPeaks(self, peaks1, peaks2):
        """Assert that two peak catalogs hold the same peaks in the same order.

        Float and integer positions and peak values are compared; peak ids
        are deliberately not compared.
        """
        self.assertEqual(len(peaks1), len(peaks2))
        for pk1, pk2 in zip(peaks1, peaks2):
            for key in ("f_x", "f_y", "i_x", "i_y", "peakValue"):
                self.assertEqual(pk1[key], pk2[key])

    def testConstructor(self):
        """Build footprints with `fromArrays`/`fromImages` and check them."""
        def projectSpans(radius, value, bbox, asArray):
            # Fill a circle of ``radius`` centred on (10, 10) with ``value``.
            ss = SpanSet.fromShape(radius, Stencil.CIRCLE, offset=(10, 10))
            image = ImageF(bbox)
            ss.setImage(image, value)
            if asArray:
                return image.array
            else:
                return image

        def runTest(images,
                    mFoot,
                    peaks=self.peaks,
                    footprintBBox=Box2I(Point2I(6, 6), Extent2I(9, 9))):
            self.assertEqual(mFoot.getBBox(), footprintBBox)
            # The default footprint bbox is the 11x11 input bbox shrunk by
            # one pixel on each side. ``images`` is either an array cube or
            # a list of `ImageF`, hence the fallback.
            try:
                fpImage = np.array(images)[:, 1:-1, 1:-1]
            except IndexError:
                fpImage = np.array([img.array for img in images])[:, 1:-1,
                                                                  1:-1]
            self.assertFloatsAlmostEqual(
                mFoot.getImage(fill=0).image.array, fpImage)
            if peaks is not None:
                self.verifyPeaks(mFoot.getPeaks(), peaks)

        bbox = Box2I(Point2I(5, 5), Extent2I(11, 11))
        xy0 = Point2I(5, 5)

        images = np.array(
            [projectSpans(n, 5 - n, bbox, True) for n in range(2, 5)])
        mFoot = MultibandFootprint.fromArrays(self.filters,
                                              images,
                                              xy0=xy0,
                                              peaks=self.peaks)
        runTest(images, mFoot)

        mFoot = MultibandFootprint.fromArrays(self.filters, images)
        runTest(images, mFoot, None, Box2I(Point2I(1, 1), Extent2I(9, 9)))

        images = [projectSpans(n, 5 - n, bbox, False) for n in range(2, 5)]
        mFoot = MultibandFootprint.fromImages(self.filters,
                                              images,
                                              peaks=self.peaks)
        runTest(images, mFoot)

        images = np.array(
            [projectSpans(n, n, bbox, True) for n in range(2, 5)])
        mFoot = MultibandFootprint.fromArrays(self.filters,
                                              images,
                                              peaks=self.peaks,
                                              xy0=bbox.getMin())
        runTest(images, mFoot)

        images = np.array(
            [projectSpans(n, 5 - n, bbox, True) for n in range(2, 5)])
        thresh = [1, 2, 2.5]
        mFoot = MultibandFootprint.fromArrays(self.filters,
                                              images,
                                              xy0=bbox.getMin(),
                                              thresh=thresh)
        footprintBBox = Box2I(Point2I(8, 8), Extent2I(5, 5))
        self.assertEqual(mFoot.getBBox(), footprintBBox)

        # Pixels at or below the per-band threshold in every band are zeroed.
        fpImage = np.array(images)[:, 3:-3, 3:-3]
        mask = np.all(fpImage <= np.array(thresh)[:, None, None], axis=0)
        fpImage[:, mask] = 0
        self.assertFloatsAlmostEqual(
            mFoot.getImage(fill=0).image.array, fpImage)
        img = mFoot.getImage().image.array
        img[~np.isfinite(img)] = 1.1
        self.assertFloatsAlmostEqual(mFoot.getImage(fill=1.1).image.array, img)

    def testSlicing(self):
        """Check filter slicing and rejection of spatial indices."""
        self.assertIsInstance(self.mFoot["R"], HeavyFootprintF)
        self.assertIsInstance(self.mFoot[:], MultibandFootprint)

        self.assertEqual(self.mFoot["I"], self.mFoot["I"])
        self.assertEqual(self.mFoot[:"I"].filters, ("G", "R"))
        self.assertEqual(self.mFoot[:"I"].getBBox(), self.bbox)
        self.assertEqual(self.mFoot[["G", "I"]].filters, ("G", "I"))
        self.assertEqual(self.mFoot[["G", "I"]].getBBox(), self.bbox)

        # Each statement gets its own context manager: inside a single
        # ``assertRaises`` block, only the first raising statement runs.
        with self.assertRaises(TypeError):
            self.mFoot["I", 4, 5]
        with self.assertRaises(TypeError):
            self.mFoot["I", :, :]
        with self.assertRaises(IndexError):
            self.mFoot[:, :, :]

    def testSpans(self):
        self.assertEqual(self.mFoot.getSpans(), self.spans)
        for footprint in self.mFoot.singles:
            self.assertEqual(footprint.getSpans(), self.spans)

    def testPeaks(self):
        self.verifyPeaks(self.peaks, self.footprint.getPeaks())
        for footprint in self.mFoot.singles:
            self.verifyPeaks(footprint.getPeaks(), self.peaks)
Esempio n. 6
0
    def run(self, dataRef, selectDataList=None):
        """Draw random points for a given patch and write them to FITS.

        Random positions are drawn uniformly over the patch coadd. Each one is
        run through a minimal single-frame measurement and annotated with:

        - the patch sky statistics;
        - the local pixel variance;
        - a full-depth flag, when a depth map is loaded.

        Parameters
        ----------
        dataRef : butler data reference
            Reference to the patch to process. Its ``dataId`` is read for
            ``"patch"`` (formatted ``"x,y"``) and ``"tract"`` when
            ``config.seed == -1``. It is read for ``"filter"``, ``"tract"``
            and ``"patch"`` when the output name is built from
            ``config.dirOutName``.
        selectDataList : `list`, optional
            Not used by this method.

        Notes
        -----
        Nothing is written (an info message is logged) when the patch's
        ``<coaddName>Coadd_forced_src`` dataset cannot be read.
        """
        # First test if the forced-src file exists; do not process the
        # patch if it doesn't. Any failure to read it counts as "missing".
        try:
            dataRef.get(self.config.coaddName + "Coadd_forced_src")
        except Exception:
            self.log.info("No forced_src file found for %s. Skipping..." % (dataRef.dataId))
            return

        self.log.info("Processing %s" % (dataRef.dataId))

        # Create a seed that depends on the patch id so that it is
        # consistent among filters.
        if self.config.seed == -1:
            p = [int(d) for d in dataRef.dataId["patch"].split(",")]
            # NOTE(review): tract*10000 + px*10 + py gives distinct seeds
            # only while patch indices stay below 10; confirm for the skymap.
            numpy.random.seed(seed=dataRef.dataId["tract"]*10000 + p[0]*10 + p[1])
        else:
            numpy.random.seed(seed=self.config.seed)

        # Sky mean and std dev for this patch, in 2" diameter apertures
        # (~12 pixels x 0.17"/pixel), measured on the sky objects of the
        # measurement catalog.
        sources = dataRef.get(self.config.coaddName + "Coadd_meas")
        sky_apertures = sources['base_CircularApertureFlux_12_0_flux'][sources['merge_peak_sky']]
        select = numpy.isfinite(sky_apertures)
        sky_mean = numpy.mean(sky_apertures[select])
        sky_std = numpy.std(sky_apertures[select])
        # NOTE: to get 5-sigma limiting magnitudes:
        # -2.5*numpy.log10(5.0*sky_std/coadd.getCalib().getFluxMag0()[0])

        # Coadd, its variance plane and the sky info.
        coadd = dataRef.get(self.config.coaddName + "Coadd_calexp")
        var = coadd.getMaskedImage().getVariance().getArray()
        skyInfo = self.getSkyInfo(dataRef)

        # WCS and reference point (wrt tract).
        wcs = coadd.getWcs()
        xy0 = coadd.getXY0()

        # Dimensions in pixels.
        dim = coadd.getDimensions()

        # Measurement algorithms (mostly copied from meas_base's
        # tests/testInputCount.py).
        measureSourcesConfig = measBase.SingleFrameMeasurementConfig()
        measureSourcesConfig.plugins.names = ['base_PixelFlags', 'base_PeakCentroid', 'base_InputCount', 'base_SdssShape']
        measureSourcesConfig.slots.centroid = "base_PeakCentroid"
        measureSourcesConfig.slots.psfFlux = None
        measureSourcesConfig.slots.apFlux = None
        measureSourcesConfig.slots.modelFlux = None
        measureSourcesConfig.slots.instFlux = None
        measureSourcesConfig.slots.calibFlux = None
        measureSourcesConfig.slots.shape = None

        # The bright-star mask flag still has to be added by hand.
        measureSourcesConfig.plugins['base_PixelFlags'].masksFpCenter.append("BRIGHT_OBJECT")
        measureSourcesConfig.plugins['base_PixelFlags'].masksFpAnywhere.append("BRIGHT_OBJECT")

        measureSourcesConfig.validate()

        # NOTE(review): the fields below (and the measurement task's own
        # fields) are added to ``self.schema`` on every call, so a second
        # call on the same task instance would presumably fail on duplicate
        # field names. Consider moving schema setup into ``__init__``.

        # Random number to adjust sky density.
        adjust_density = self.schema.addField("adjust_density", type=float, doc="Random number between [0:1] to adjust sky density", units='')

        # Sky mean and std dev for the entire patch.
        sky_mean_key = self.schema.addField("sky_mean", type=float, doc="Mean of sky value in 2\" diamter apertures", units='count')
        sky_std_key = self.schema.addField("sky_std", type=float, doc="Standard deviation of sky value in 2\" diamter apertures", units='count')

        # Pixel variance at the random point position.
        pix_variance = self.schema.addField("pix_variance", type=float, doc="Pixel variance at random point position", units="flx^2")

        # Healpix depth-map flag (only if a depth map is given).
        useDepthMap = self.depthMap.map is not None
        if useDepthMap:
            depth_key = self.schema.addField("isFullDepthColor", type="Flag", doc="True if full depth and full colors at point position", units='')

        # Task and output catalog.
        task = measBase.SingleFrameMeasurementTask(self.schema, config=measureSourcesConfig)
        table = afwTable.SourceTable.make(self.schema, self.makeIdFactory(dataRef))
        catalog = afwTable.SourceCatalog(table)

        if self.config.N == -1:
            # Constant random density: area in deg^2 times 3600.
            # NOTE(review): this presumably means Nden is per arcmin^2; confirm.
            pixel_area = coadd.getWcs().getPixelScale().asDegrees()**2
            area = pixel_area * dim[0] * dim[1]
            N = self.iround(area*self.config.Nden*60.0*60.0)
        else:
            # Fixed number of random points.
            N = self.config.N

        self.log.info("Drawing %d random points" % (N))

        for _ in range(N):
            # Draw one random point in patch-local pixel coordinates.
            x = numpy.random.random()*(dim[0]-1)
            y = numpy.random.random()*(dim[1]-1)

            # Round-trip through the WCS to get sky and parent-pixel coordinates.
            radec = wcs.pixelToSky(afwGeom.Point2D(x + xy0.getX(), y + xy0.getY()))
            xy = wcs.skyToPixel(radec)

            record = catalog.addNew()
            record.setCoord(radec)

            # The object has no real footprint: a radius-0 span at the point.
            radius = 0
            spanset1 = SpanSet.fromShape(radius, stencil=Stencil.CIRCLE, offset=afwGeom.Point2I(xy))
            foot = Footprint(spanset1)
            foot.addPeak(xy[0], xy[1], 0.0)
            record.setFootprint(foot)

            # Draw a number between 0 and 1 to adjust sky density.
            record.set(adjust_density, numpy.random.random())

            # Sky properties.
            record.set(sky_mean_key, sky_mean)
            record.set(sky_std_key, sky_std)

            # Local (pixel) variance.
            record.set(pix_variance, float(var[self.iround(y), self.iround(x)]))

            # Required for setPrimaryFlags.
            record.set(catalog.getCentroidKey(), afwGeom.Point2D(xy))

            # Healpix map value: theta = pi/2 - dec, phi = ra.
            if useDepthMap:
                mapIndex = healpy.pixelfunc.ang2pix(self.depthMap.nside, numpy.pi/2.0 - radec[1].asRadians(), radec[0].asRadians(), nest=self.depthMap.nest)
                record.setFlag(depth_key, self.depthMap.map[mapIndex])

        # Run measurements.
        task.run(catalog, coadd)

        self.setPrimaryFlags.run(catalog, skyInfo.skyMap, skyInfo.tractInfo, skyInfo.patchInfo, includeDeblend=False)

        # Write the catalog.
        if self.config.fileOutName == "":
            if self.config.dirOutName == "":
                fileOutName = dataRef.get(self.config.coaddName + "Coadd_forced_src_filename")[0].replace('forced_src', 'ran')
                self.log.info("WARNING: the output file will be written in {0:s}.".format(fileOutName))
            else:
                fileOutName = "{0}/{1}/{2}/{3}/ran-{1}-{2}-{3}.fits".format(self.config.dirOutName, dataRef.dataId["filter"], dataRef.dataId["tract"], dataRef.dataId["patch"])
        else:
            fileOutName = self.config.fileOutName

        self.mkdir_p(os.path.dirname(fileOutName))
        catalog.writeFits(fileOutName)

        # TODO: define the output name in __init__ (not in paf), allow
        # parallel processing, and optionally write the sources.
Esempio n. 7
0
def modelToHeavy(source, mExposure, blend, xy0=Point2I(), dtype=np.float32):
    """Render a scarlet source model as a `MultibandFootprint`.

    The steps are:

    - evaluate the model on the source box grown by the PSF size, clipped
      to the model frame;
    - convolve it with the observed PSF in every band;
    - wrap it in a multiband heavy footprint whose spans cover every pixel
      where the band-wise maximum of the convolved model is non-zero.

    Parameters
    ----------
    source : `scarlet.Component`
        The scarlet source to convert.
    mExposure : `lsst.afw.image.MultibandExposure`
        Multiband exposure; only its ``filters`` are used here.
    blend : `scarlet.Blend`
        Blend whose first observation supplies the PSF model and the
        renderer used to convolve the model to the observed seeing.
    xy0 : `lsst.geom.Point2I`
        ``(x, y)`` coordinates of the lower-left pixel of the entire blend.
    dtype : `numpy.dtype`
        Data type of the model stored in the returned footprint.

    Returns
    -------
    mHeavy : `lsst.afw.detection.MultibandFootprint`
        Multi-band footprint containing the PSF-convolved model of ``source``.
    """
    observation = blend.observations[0]

    # Grow the source box by the PSF size so that all of the flux spread by
    # the convolution is captured. The band axis (index 0) is not grown.
    psfHeight, psfWidth = observation.psf.get_model().shape[1:]
    srcShape = source.bbox.shape
    srcOrigin = source.bbox.origin
    grownShape = (srcShape[0],
                  srcShape[1] + psfHeight,
                  srcShape[2] + psfWidth)
    grownOrigin = (srcOrigin[0],
                   srcOrigin[1] - psfHeight // 2,
                   srcOrigin[2] - psfWidth // 2)

    # ``&`` intersects two scarlet boxes: keep only what lies in the frame.
    overlap = Box(grownShape, origin=grownOrigin) & source.frame.bbox

    # Evaluate the full multiband model in that box and convolve each band
    # with the PSF; real-space convolution is used to limit artifacts.
    convolved = observation.renderer.convolve(
        source.model_to_box(overlap), convolution_type="real").astype(dtype)

    # Shift the overlap origin (last axis = x, second-to-last = y) into the
    # coordinate system of the whole blend.
    cornerXY = Point2I(overlap.origin[-1] + xy0.x, overlap.origin[-2] + xy0.y)

    # Spans cover every pixel whose band-wise maximum is non-zero.
    nonZero = np.max(np.array(convolved), axis=0) != 0
    validMask = Mask(nonZero.astype(np.int32), xy0=cornerXY)
    foot = Footprint(SpanSet.fromMask(validMask))

    # The source's detected peak becomes the footprint's only peak.
    peakCat = PeakCatalog(source.detectedPeak.table)
    peakCat.append(source.detectedPeak)
    foot.setPeakCatalog(peakCat)

    modelImage = MultibandImage(mExposure.filters, convolved, validMask.getBBox())
    return MultibandFootprint.fromImages(mExposure.filters,
                                         modelImage,
                                         footprint=foot)
Esempio n. 8
0
class MultibandFootprintTestCase(lsst.utils.tests.TestCase):
    """Tests for `MultibandFootprint`.

    Covers construction from arrays and images, filter slicing, spans and
    peaks.
    """

    def setUp(self):
        """Build a three-band fixture from constant images of value 0, 1, 2."""
        np.random.seed(1)
        self.spans = SpanSet.fromShape(2, Stencil.CIRCLE)
        self.footprint = Footprint(self.spans)
        self.footprint.addPeak(3, 4, 10)
        self.footprint.addPeak(8, 1, 2)
        # Reference peak catalog, rebuilt on a separate footprint.
        fp = Footprint(self.spans)
        for peak in self.footprint.getPeaks():
            fp.addPeak(peak["f_x"], peak["f_y"], peak["peakValue"])
        self.peaks = fp.getPeaks()
        self.bbox = self.footprint.getBBox()
        self.filters = ("G", "R", "I")
        singles = []
        images = []
        for n, f in enumerate(self.filters):
            image = ImageF(self.spans.getBBox())
            image.set(n)
            images.append(image.array)
            maskedImage = MaskedImageF(image)
            heavy = makeHeavyFootprint(self.footprint, maskedImage)
            singles.append(heavy)
        self.image = np.array(images)
        self.mFoot = MultibandFootprint(self.filters, singles)

    def tearDown(self):
        del self.spans
        del self.footprint
        del self.peaks
        del self.bbox
        del self.filters
        del self.mFoot
        del self.image

    def verifyPeaks(self, peaks1, peaks2):
        """Assert that two peak catalogs hold the same peaks in the same order.

        Float and integer positions and peak values are compared; peak ids
        are deliberately not compared.
        """
        self.assertEqual(len(peaks1), len(peaks2))
        for pk1, pk2 in zip(peaks1, peaks2):
            for key in ("f_x", "f_y", "i_x", "i_y", "peakValue"):
                self.assertEqual(pk1[key], pk2[key])

    def testConstructor(self):
        """Build footprints with `fromArrays`/`fromImages` and check them."""
        def projectSpans(radius, value, bbox, asArray):
            # Fill a circle of ``radius`` centred on (10, 10) with ``value``.
            ss = SpanSet.fromShape(radius, Stencil.CIRCLE, offset=(10, 10))
            image = ImageF(bbox)
            ss.setImage(image, value)
            if asArray:
                return image.array
            else:
                return image

        def runTest(images, mFoot, peaks=self.peaks, footprintBBox=Box2I(Point2I(6, 6), Extent2I(9, 9))):
            self.assertEqual(mFoot.getBBox(), footprintBBox)
            # The default footprint bbox is the 11x11 input bbox shrunk by
            # one pixel on each side. ``images`` is either an array cube or
            # a list of `ImageF`, hence the fallback.
            try:
                fpImage = np.array(images)[:, 1:-1, 1:-1]
            except IndexError:
                fpImage = np.array([img.array for img in images])[:, 1:-1, 1:-1]
            self.assertFloatsAlmostEqual(mFoot.getImage(fill=0).image.array, fpImage)
            if peaks is not None:
                self.verifyPeaks(mFoot.getPeaks(), peaks)

        bbox = Box2I(Point2I(5, 5), Extent2I(11, 11))
        xy0 = Point2I(5, 5)

        images = np.array([projectSpans(n, 5-n, bbox, True) for n in range(2, 5)])
        mFoot = MultibandFootprint.fromArrays(self.filters, images, xy0=xy0, peaks=self.peaks)
        runTest(images, mFoot)

        mFoot = MultibandFootprint.fromArrays(self.filters, images)
        runTest(images, mFoot, None, Box2I(Point2I(1, 1), Extent2I(9, 9)))

        images = [projectSpans(n, 5-n, bbox, False) for n in range(2, 5)]
        mFoot = MultibandFootprint.fromImages(self.filters, images, peaks=self.peaks)
        runTest(images, mFoot)

        images = np.array([projectSpans(n, n, bbox, True) for n in range(2, 5)])
        mFoot = MultibandFootprint.fromArrays(self.filters, images, peaks=self.peaks, xy0=bbox.getMin())
        runTest(images, mFoot)

        images = np.array([projectSpans(n, 5-n, bbox, True) for n in range(2, 5)])
        thresh = [1, 2, 2.5]
        mFoot = MultibandFootprint.fromArrays(self.filters, images, xy0=bbox.getMin(), thresh=thresh)
        footprintBBox = Box2I(Point2I(8, 8), Extent2I(5, 5))
        self.assertEqual(mFoot.getBBox(), footprintBBox)

        # Pixels at or below the per-band threshold in every band are zeroed.
        fpImage = np.array(images)[:, 3:-3, 3:-3]
        mask = np.all(fpImage <= np.array(thresh)[:, None, None], axis=0)
        fpImage[:, mask] = 0
        self.assertFloatsAlmostEqual(mFoot.getImage(fill=0).image.array, fpImage)
        img = mFoot.getImage().image.array
        img[~np.isfinite(img)] = 1.1
        self.assertFloatsAlmostEqual(mFoot.getImage(fill=1.1).image.array, img)

    def testSlicing(self):
        """Check filter slicing and rejection of spatial indices."""
        self.assertIsInstance(self.mFoot["R"], HeavyFootprintF)
        self.assertIsInstance(self.mFoot[:], MultibandFootprint)

        self.assertEqual(self.mFoot["I"], self.mFoot["I"])
        self.assertEqual(self.mFoot[:"I"].filters, ("G", "R"))
        self.assertEqual(self.mFoot[:"I"].getBBox(), self.bbox)
        self.assertEqual(self.mFoot[["G", "I"]].filters, ("G", "I"))
        self.assertEqual(self.mFoot[["G", "I"]].getBBox(), self.bbox)

        # Each statement gets its own context manager: inside a single
        # ``assertRaises`` block, only the first raising statement runs.
        with self.assertRaises(TypeError):
            self.mFoot["I", 4, 5]
        with self.assertRaises(TypeError):
            self.mFoot["I", :, :]
        with self.assertRaises(IndexError):
            self.mFoot[:, :, :]

    def testSpans(self):
        self.assertEqual(self.mFoot.getSpans(), self.spans)
        for footprint in self.mFoot.singles:
            self.assertEqual(footprint.getSpans(), self.spans)

    def testPeaks(self):
        self.verifyPeaks(self.peaks, self.footprint.getPeaks())
        for footprint in self.mFoot.singles:
            self.verifyPeaks(footprint.getPeaks(), self.peaks)