Ejemplo n.º 1
0
    def testDataAuxTelZWO(self):
        """Check the instrument data reported for the AuxTel ZWO camera.

        The instrument is configured with CamType.AuxTelZWO, a size of 160
        and an announced defocal distance of 0.5 mm before its getters are
        compared with the expected values.
        """

        auxTelZwo = Instrument(self.instDir)
        auxTelZwo.config(CamType.AuxTelZWO, 160, announcedDefocalDisInMm=0.5)

        # Each getter is paired with the exact value it has to report. The
        # getters are called one at a time, in the listed order.
        exactChecks = (
            (auxTelZwo.getObscuration, 0.3525),
            (auxTelZwo.getFocalLength, 21.6),
            (auxTelZwo.getApertureDiameter, 1.2),
            (auxTelZwo.getDefocalDisOffset, 0.0205),
            (auxTelZwo.getCamPixelSize, 15.2e-6),
        )
        for getter, expected in exactChecks:
            self.assertEqual(getter(), expected)

        # The expected donut size is a computed value, so it is only
        # compared to 7 decimal places.
        expectedDonutSize = 74.92690058
        self.assertAlmostEqual(
            auxTelZwo.calcSizeOfDonutExpected(), expectedDonutSize, places=7)
Ejemplo n.º 2
0
    def _runWep(self, imgIntraName, imgExtraName, offset, model):
        """Run the "exp" solver on a pair of AuxTel test images.

        Each input image is cropped to a square around the donut center
        found with the Otsu centroid finder before the wavefront error is
        calculated.

        Parameters
        ----------
        imgIntraName : str
            File name of the intra-focal image in self.testImgDir.
        imgExtraName : str
            File name of the extra-focal image in self.testImgDir.
        offset : int
            Half size of the crop: the cut spans center - offset to
            center + offset along both axes.
        model : str
            Optical model passed to Algorithm.runIt().

        Returns
        -------
        Value returned by Algorithm.getZer4UpInNm() after the run.
        """

        centroidFindType = CentroidFindType.Otsu

        def cropDonut(imgName):
            # Load one test image and cut the square around its donut center
            donutImg = Image(centroidFindType=centroidFindType)
            donutImg.setImg(imageFile=os.path.join(self.testImgDir, imgName))

            centerX, centerY, _ = donutImg.getCenterAndR()
            col, row = int(centerX), int(centerY)

            # Rows follow y and columns follow x
            return donutImg.getImg()[row - offset:row + offset,
                                     col - offset:col + offset]

        intraCut = cropDonut(imgIntraName)
        extraCut = cropDonut(imgExtraName)

        # Wrap the crops as compensable images at the field center
        fieldXY = (0, 0)
        donutIntra = CompensableImage(centroidFindType=centroidFindType)
        donutIntra.setImg(fieldXY, DefocalType.Intra, image=intraCut)

        donutExtra = CompensableImage(centroidFindType=centroidFindType)
        donutExtra.setImg(fieldXY, DefocalType.Extra, image=extraCut)

        # Instrument sized from the cropped intra-focal image
        auxTel = Instrument(os.path.join(getConfigDir(), "cwfs", "instData"))
        auxTel.config(
            CamType.AuxTel,
            donutIntra.getImgSizeInPix(),
            announcedDefocalDisInMm=0.8,
        )

        # Algorithm using the "exp" solver
        solver = Algorithm(os.path.join(getConfigDir(), "cwfs", "algo"))
        solver.config("exp", auxTel)
        solver.runIt(donutIntra, donutExtra, model)

        return solver.getZer4UpInNm()
Ejemplo n.º 3
0
    def testMakeMasks(self):
        """Check that makeMasks() fills both masks of a donut stamp.

        The masks start empty. After makeMasks() they must be MaskX objects
        with the BKGRD/DONUT planes, have the donut width as side length, be
        non-empty, and be symmetric under a flip of the first axis.
        """

        stamp = DonutStamp(
            self.testStamps[0],
            lsst.geom.SpherePoint(0.0, 0.0, lsst.geom.degrees),
            lsst.geom.Point2D(2047.5, 2001.5),
            DefocalType.Extra.value,
            "R22_S11",
            "LSSTCam",
        )

        # Instrument configured for the donut width used in the checks below
        donutWidth = 126
        instrument = Instrument(
            os.path.join(getConfigDir(), "cwfs", "instData"))
        instrument.config(CamType.LsstCam, donutWidth)

        def currentMasks():
            # Read the attributes on every call: makeMasks() is expected to
            # change what they hold
            return (stamp.mask_comp, stamp.mask_pupil)

        # Both masks are empty at the start
        for mask in currentMasks():
            np.testing.assert_array_equal(np.empty(shape=(0, 0)),
                                          mask.getArray())

        stamp.makeMasks(instrument, "offAxis", 0, 1)

        # Type and mask planes after creation
        for mask in currentMasks():
            self.assertEqual(afwImage.MaskX, type(mask))

        expectedPlanes = {"BKGRD": 0, "DONUT": 1}
        for mask in currentMasks():
            self.assertDictEqual(expectedPlanes, mask.getMaskPlaneDict())

        maskArrays = tuple(mask.getArray() for mask in currentMasks())

        # The mask shape has to match the donut
        for maskArray in maskArrays:
            self.assertEqual(np.shape(maskArray), (donutWidth, donutWidth))

        # The masks must not be all zero
        for maskArray in maskArrays:
            self.assertTrue(np.sum(maskArray) > 0.0)

        # A donut at the center of the focal plane should be symmetric: the
        # first half of the rows equals the flipped second half
        half = donutWidth // 2
        for maskArray in maskArrays:
            np.testing.assert_array_equal(maskArray[:half],
                                          maskArray[-half:][::-1])
Ejemplo n.º 4
0
class WfEstimator(object):
    """Estimate the wavefront error from a pair of defocal donut images.

    The estimator owns an Instrument, an Algorithm, and one CompensableImage
    for each of the intra- and extra-focal donuts. config() has to be called
    before the images are set: it defines the required image size and may
    replace the image objects.
    """

    def __init__(self, instDir, algoDir):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instDir : str
            Path to instrument directory.
        algoDir : str
            Path to algorithm directory.
        """

        self.inst = Instrument(instDir)
        self.algo = Algorithm(algoDir)

        self.imgIntra = CompensableImage()
        self.imgExtra = CompensableImage()

        # Both are overwritten by config()
        self.opticalModel = ""
        self.sizeInPix = 0

    def getAlgo(self):
        """Get the algorithm object.

        Returns
        -------
        Algorithm
            Algorithm object.
        """

        return self.algo

    def getInst(self):
        """Get the instrument object.

        Returns
        -------
        Instrument
            Instrument object.
        """

        return self.inst

    def getIntraImg(self):
        """Get the intra-focal donut image.

        Returns
        -------
        CompensableImage
            Intra-focal donut image.
        """

        return self.imgIntra

    def getExtraImg(self):
        """Get the extra-focal donut image.

        Returns
        -------
        CompensableImage
            Extra-focal donut image.
        """

        return self.imgExtra

    def getOptModel(self):
        """Get the optical model.

        Returns
        -------
        str
            Optical model. This is an empty string until config() is called.
        """

        return self.opticalModel

    def getSizeInPix(self):
        """Get the donut image size in pixel defined by the config() function.

        Returns
        -------
        int
            Donut image size in pixel. This is 0 until config() is called.
        """

        return self.sizeInPix

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings.
        """

        self.algo.reset()

    def config(
        self,
        solver="exp",
        camType=CamType.LsstCam,
        opticalModel="offAxis",
        defocalDisInMm=1.5,
        sizeInPix=120,
        centroidFindType=CentroidFindType.RandomWalk,
        debugLevel=0,
    ):
        """Configure the TIE solver.

        Parameters
        ----------
        solver : str, optional
            Algorithm to solve the Poisson's equation in the transport of
            intensity equation (TIE). It can be "fft" or "exp" here. (the
            default is "exp".)
        camType : enum 'CamType', optional
            Camera type. (the default is CamType.LsstCam.)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis". (the
            default is "offAxis".)
        defocalDisInMm : float, optional
            Defocal distance in mm. (the default is 1.5.)
        sizeInPix : int, optional
            Donut image size in pixel. The value is converted with int().
            (the default is 120.)
        centroidFindType : enum 'CentroidFindType', optional
            Algorithm to find the centroid of donut. (the default is
            CentroidFindType.RandomWalk.)
        debugLevel : int, optional
            Show the information under the running. If the value is higher,
            the information shows more. It can be 0, 1, 2, or 3. (the default
            is 0.)

        Raises
        ------
        ValueError
            Wrong Poisson solver name.
        ValueError
            Wrong optical model.

        Notes
        -----
        When centroidFindType is not CentroidFindType.RandomWalk, the intra-
        and extra-focal image objects are replaced by new ones, so images set
        before this call are dropped. When it is CentroidFindType.RandomWalk,
        the existing image objects are kept as they are.
        """

        # Validate both names before any state is changed
        if solver not in ("exp", "fft"):
            raise ValueError("Poisson solver can not be '%s'." % solver)

        if opticalModel not in ("paraxial", "onAxis", "offAxis"):
            raise ValueError("Optical model can not be '%s'." % opticalModel)

        self.opticalModel = opticalModel

        # Configure the instrument with the donut size and defocal distance
        self.sizeInPix = int(sizeInPix)
        self.inst.config(camType,
                         self.sizeInPix,
                         announcedDefocalDisInMm=defocalDisInMm)

        self.algo.config(solver, self.inst, debugLevel=debugLevel)

        # Reset the centroid find algorithm if not the default one
        if centroidFindType != CentroidFindType.RandomWalk:
            self.imgIntra = CompensableImage(centroidFindType=centroidFindType)
            self.imgExtra = CompensableImage(centroidFindType=centroidFindType)

    def setImg(self, fieldXY, defocalType, image=None, imageFile=None):
        """Set the wavefront image.

        Parameters
        ----------
        fieldXY : tuple or list
            Position of donut on the focal plane in degree for intra- and
            extra-focal images.
        defocalType : enum 'DefocalType'
            Defocal type of image. It has to be DefocalType.Intra or
            DefocalType.Extra.
        image : numpy.ndarray, optional
            Array of image. (the default is None.)
        imageFile : str, optional
            Path of image file. (the default is None.)

        Raises
        ------
        ValueError
            Defocal type is neither intra- nor extra-focal.
        """

        if defocalType == DefocalType.Intra:
            img = self.imgIntra
        elif defocalType == DefocalType.Extra:
            img = self.imgExtra
        else:
            # Without this branch an unsupported type would only surface as
            # an unbound local variable on the next line
            raise ValueError("Defocal type can not be '%s'." %
                             (defocalType, ))

        img.setImg(fieldXY, defocalType, image=image, imageFile=imageFile)

    def calWfsErr(self, tol=1e-3, showZer=False, showPlot=False):
        """Calculate the wavefront error.

        Parameters
        ----------
        tol : float, optional
            Tolerance forwarded to Algorithm.runIt(). (the default is 1e-3.)
        showZer : bool, optional
            Decide to show the annular Zernike polynomials or not. (the default
            is False.)
        showPlot : bool, optional
            Decide to show the plot or not. This is only used when showZer is
            True. (the default is False.)

        Returns
        -------
        numpy.ndarray
            Coefficients of Zernike polynomials (z4 - z22).

        Raises
        ------
        RuntimeError
            Input image shape is wrong.
        """

        # Both images have to be squares of the configured size
        for img in (self.imgIntra, self.imgExtra):
            d1, d2 = img.getImg().shape
            if (d1 != self.sizeInPix) or (d2 != self.sizeInPix):
                raise RuntimeError(
                    "Input image shape is (%d, %d), not required (%d, %d)" %
                    (d1, d2, self.sizeInPix, self.sizeInPix))

        # Calculate the wavefront error.
        # Run cwfs
        self.algo.runIt(self.imgIntra,
                        self.imgExtra,
                        self.opticalModel,
                        tol=tol)

        # Show the Zernikes Zn (n>=4)
        if showZer:
            self.algo.outZer4Up(showPlot=showPlot)

        return self.algo.getZer4UpInNm()
Ejemplo n.º 5
0
    def _runWEP(
        self,
        instDir,
        algoFolderPath,
        useAlgorithm,
        imageFolderPath,
        intra_image_name,
        extra_image_name,
        fieldXY,
        opticalModel,
        showFig=False,
    ):
        """Run the wavefront estimation on one pair of donut image files.

        Parameters
        ----------
        instDir : str
            Instrument directory passed to Instrument().
        algoFolderPath : str
            Algorithm directory passed to Algorithm().
        useAlgorithm : str
            Solver name passed to Algorithm.config().
        imageFolderPath : str
            Directory that holds both image files.
        intra_image_name : str
            File name of the intra-focal image.
        extra_image_name : str
            File name of the extra-focal image.
        fieldXY : tuple or list
            Field position set on both images.
        opticalModel : str
            Optical model passed to Algorithm.runIt().
        showFig : bool, optional
            Plot the images before and after the run and the final
            wavefront. (the default is False.)

        Returns
        -------
        Value returned by Algorithm.getZer4UpInNm() after the run.
        """

        intraPath, extraPath = (
            os.path.join(imageFolderPath, fileName)
            for fileName in (intra_image_name, extra_image_name))

        # One compensable image per side of focus
        donutIntra = CompensableImage()
        donutExtra = CompensableImage()

        donutIntra.setImg(fieldXY, DefocalType.Intra, imageFile=intraPath)
        donutExtra.setImg(fieldXY, DefocalType.Extra, imageFile=extraPath)

        # Instrument sized from the intra-focal image
        instrument = Instrument(instDir)
        instrument.config(
            CamType.LsstCam,
            donutIntra.getImgSizeInPix(),
            announcedDefocalDisInMm=1.0,
        )

        solver = Algorithm(algoFolderPath)
        solver.config(useAlgorithm, instrument, debugLevel=0)

        # Original wavefront images
        if showFig:
            for donut, title in ((donutIntra, "intra image"),
                                 (donutExtra, "extra image")):
                plotImage(donut.image, title=title)

        solver.runIt(donutIntra, donutExtra, opticalModel, tol=1e-3)

        # Output the Zernikes Zn (n>=4) without a plot
        solver.outZer4Up(showPlot=False)

        if showFig:
            # Images after the run
            for donut, title in ((donutIntra, "Compensated intra image"),
                                 (donutExtra, "Compensated extra image")):
                plotImage(donut.image, title=title)

            # Final wavefront, without and with the pupil mask
            plotImage(solver.wcomp, title="Final wavefront")
            plotImage(
                solver.wcomp,
                title="Final wavefront with pupil mask applied",
                mask=solver.pMask,
            )

        return solver.getZer4UpInNm()
Ejemplo n.º 6
0
class TestAlgorithm(unittest.TestCase):
    """Test the Algorithm class."""

    def setUp(self):
        """Load the test donut pair and configure the "exp" and "fft"
        algorithms on an LsstCam instrument."""

        # Get the path of module
        self.modulePath = getModulePath()

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # It is noted that the image file inputs are based on the txt file
        imageFolderPath = os.path.join(self.modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"

        # Define fieldXY: [1.185, 1.185] or [0, 0]
        # This is the position of donut on the focal plane in degree
        fieldXY = [1.185, 1.185]

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Image files Path
        intra_image_file = os.path.join(imageFolderPath, intra_image_name)
        extra_image_file = os.path.join(imageFolderPath, extra_image_name)

        # There is the difference between intra and extra images
        # I1: intra_focal images, I2: extra_focal Images
        self.I1 = CompensableImage()
        self.I2 = CompensableImage()

        self.I1.setImg(fieldXY, DefocalType.Intra, imageFile=intra_image_file)
        self.I2.setImg(fieldXY, DefocalType.Extra, imageFile=extra_image_file)

        # Set up the instrument
        cwfsConfigDir = os.path.join(getConfigDir(), "cwfs")

        instDir = os.path.join(cwfsConfigDir, "instData")
        self.inst = Instrument(instDir)

        self.inst.config(CamType.LsstCam,
                         self.I1.getImgSizeInPix(),
                         announcedDefocalDisInMm=1.0)

        # Set up the algorithm
        algoDir = os.path.join(cwfsConfigDir, "algo")

        self.algoExp = Algorithm(algoDir)
        self.algoExp.config("exp", self.inst)

        self.algoFft = Algorithm(algoDir)
        self.algoFft.config("fft", self.inst)

    def testGetDebugLevel(self):
        """The debug level is 0 when config() is called without one."""

        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testSetDebugLevel(self):
        """The debug level follows config() and setDebugLevel()."""

        self.algoExp.config("exp", self.inst, debugLevel=3)
        self.assertEqual(self.algoExp.getDebugLevel(), 3)

        self.algoExp.setDebugLevel(0)
        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testGetZer4UpInNm(self):
        """The Zernike coefficients are returned as a numpy array."""

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertIsInstance(zer4UpNm, np.ndarray)

    def testGetPoissonSolverName(self):
        """Each algorithm reports the solver name it was configured with."""

        self.assertEqual(self.algoExp.getPoissonSolverName(), "exp")
        self.assertEqual(self.algoFft.getPoissonSolverName(), "fft")

    def testGetNumOfZernikes(self):
        """Both solvers use 22 Zernike terms."""

        self.assertEqual(self.algoExp.getNumOfZernikes(), 22)
        self.assertEqual(self.algoFft.getNumOfZernikes(), 22)

    def testGetZernikeTerms(self):
        """The Zernike terms are integers, one per Zernike polynomial."""

        # assertTrue(type(x), int) would take int as the failure message and
        # could never fail, so the type is checked explicitly. numpy integers
        # are accepted as well as the builtin int.
        integerTypes = (int, np.integer)

        zTerms = self.algoExp.getZernikeTerms()
        self.assertIsInstance(zTerms[0], integerTypes)
        self.assertEqual(len(zTerms), self.algoExp.getNumOfZernikes())
        self.assertEqual(zTerms[1], 1)
        self.assertEqual(zTerms[-1], self.algoExp.getNumOfZernikes() - 1)

        # The fft terms are compared with the fft algorithm's own count
        zTerms = self.algoFft.getZernikeTerms()
        self.assertIsInstance(zTerms[0], integerTypes)
        self.assertEqual(len(zTerms), self.algoFft.getNumOfZernikes())

    def testGetObsOfZernikes(self):
        """The Zernike obscuration equals the instrument obscuration."""

        self.assertEqual(self.algoExp.getObsOfZernikes(),
                         self.inst.getObscuration())
        self.assertEqual(self.algoFft.getObsOfZernikes(),
                         self.inst.getObscuration())

    def testGetNumOfOuterItr(self):
        """Both solvers use 14 outer iterations."""

        self.assertEqual(self.algoExp.getNumOfOuterItr(), 14)
        self.assertEqual(self.algoFft.getNumOfOuterItr(), 14)

    def testGetNumOfInnerItr(self):
        """The fft solver uses 6 inner iterations."""

        self.assertEqual(self.algoFft.getNumOfInnerItr(), 6)

    def testGetFeedbackGain(self):
        """Both solvers use a feedback gain of 0.6."""

        self.assertEqual(self.algoExp.getFeedbackGain(), 0.6)
        self.assertEqual(self.algoFft.getFeedbackGain(), 0.6)

    def testGetOffAxisPolyOrder(self):
        """Both solvers use an off-axis polynomial order of 10."""

        self.assertEqual(self.algoExp.getOffAxisPolyOrder(), 10)
        self.assertEqual(self.algoFft.getOffAxisPolyOrder(), 10)

    def testGetCompensatorMode(self):
        """Both solvers use the "zer" compensator mode."""

        self.assertEqual(self.algoExp.getCompensatorMode(), "zer")
        self.assertEqual(self.algoFft.getCompensatorMode(), "zer")

    def testGetCompSequence(self):
        """The compensation sequence has one integer entry per outer
        iteration, starting at 4 and ending at 22 for the exp solver."""

        compSequence = self.algoExp.getCompSequence()
        self.assertIsInstance(compSequence, np.ndarray)
        self.assertEqual(compSequence.dtype, int)
        self.assertEqual(len(compSequence), self.algoExp.getNumOfOuterItr())
        self.assertEqual(compSequence[0], 4)
        self.assertEqual(compSequence[-1], 22)

        compSequence = self.algoFft.getCompSequence()
        self.assertEqual(len(compSequence), self.algoFft.getNumOfOuterItr())

    def testGetBoundaryThickness(self):
        """The boundary thickness is 8 for exp and 1 for fft."""

        self.assertEqual(self.algoExp.getBoundaryThickness(), 8)
        self.assertEqual(self.algoFft.getBoundaryThickness(), 1)

    def testGetFftDimension(self):
        """The fft solver uses a dimension of 128."""

        self.assertEqual(self.algoFft.getFftDimension(), 128)

    def testGetSignalClipSequence(self):
        """The fft signal clip sequence has one entry more than the number
        of outer iterations and runs from 0.33 to 0.51."""

        sumclipSequence = self.algoFft.getSignalClipSequence()
        self.assertIsInstance(sumclipSequence, np.ndarray)
        # The sequence comes from the fft algorithm, so its length is
        # compared with that algorithm's own iteration count
        self.assertEqual(len(sumclipSequence),
                         self.algoFft.getNumOfOuterItr() + 1)
        self.assertEqual(sumclipSequence[0], 0.33)
        self.assertEqual(sumclipSequence[-1], 0.51)

    def testGetMaskScalingFactor(self):
        """Both solvers report a mask scaling factor of about 1.0939."""

        self.assertAlmostEqual(self.algoExp.getMaskScalingFactor(),
                               1.0939,
                               places=4)
        self.assertAlmostEqual(self.algoFft.getMaskScalingFactor(),
                               1.0939,
                               places=4)

    def testGetWavefrontMapEstiInIter0(self):
        """Asking for the wavefront map before any run raises an error."""

        self.assertRaises(RuntimeError, self.algoExp.getWavefrontMapEsti)

    def testGetWavefrontMapEstiAndResidual(self):
        """After a full run the estimated map is significant and the
        residual map is small."""

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # NaN entries are zeroed so that they do not poison the sums
        wavefrontMapEsti = self.algoExp.getWavefrontMapEsti()
        wavefrontMapEsti[np.isnan(wavefrontMapEsti)] = 0
        self.assertGreater(np.sum(np.abs(wavefrontMapEsti)), 4.8e-4)

        wavefrontMapResidual = self.algoExp.getWavefrontMapResidual()
        wavefrontMapResidual[np.isnan(wavefrontMapResidual)] = 0
        self.assertLess(np.sum(np.abs(wavefrontMapResidual)), 2.5e-6)

    def testItr0(self):
        """The initial iteration reproduces the reference coefficients."""

        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertEqual(
            np.sum(np.abs(np.rint(zer4UpNm) - self._getAnsItr0())), 0)

    def _getAnsItr0(self):
        """Get the reference coefficients of the initial iteration.

        Returns
        -------
        list of int
            Values that the rounded result of getZer4UpInNm() is compared
            with after the initial iteration.
        """

        return [
            31,
            -69,
            -21,
            84,
            44,
            -53,
            48,
            -146,
            6,
            10,
            13,
            -5,
            1,
            -12,
            -8,
            7,
            0,
            -6,
            11,
        ]

    def testNextItrWithOneIter(self):
        """One call of nextItr() with nItr=1 gives the same rounded
        coefficients as the initial iteration."""

        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=1)

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertEqual(
            np.sum(np.abs(np.rint(zer4UpNm) - self._getAnsItr0())), 0)

    def testNextItrWithTwoIter(self):
        """Two iterations reproduce the reference coefficients."""

        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=2)
        zer4UpNm = self.algoExp.getZer4UpInNm()

        ansRint = [
            40,
            -80,
            -18,
            92,
            44.0,
            -52,
            54,
            -146,
            5,
            10,
            15,
            -3,
            -0,
            -12,
            -8,
            7,
            1,
            -3,
            12,
        ]
        self.assertEqual(np.sum(np.abs(np.rint(zer4UpNm) - ansRint)), 0)

    def testIter0AndNextIterToCheckReset(self):
        """itr0() gives the same result before and after nextItr()."""

        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp1 = self.algoExp.getZer4UpInNm()

        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=2)

        # itr0() should reset the images and ignore the effect from nextItr()
        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp2 = self.algoExp.getZer4UpInNm()

        difference = np.sum(np.abs(tmp1 - tmp2))
        self.assertEqual(difference, 0)

    def testRunItOfExp(self):
        """A full run of the exp solver gives the reference value."""

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Check the value
        zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)

    def testResetAfterFullCalc(self):
        """A full run repeated after a reset gives the same value."""

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Put the initial images back, reset the algorithm, and check the
        # calculation again
        fieldXY = [self.I1.fieldX, self.I1.fieldY]
        self.I1.setImg(fieldXY,
                       self.I1.getDefocalType(),
                       image=self.I1.getImgInit())
        self.I2.setImg(fieldXY,
                       self.I2.getDefocalType(),
                       image=self.I2.getImgInit())
        self.algoExp.reset()

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)

    def testRunItOfFft(self):
        """A full run of the fft solver gives the reference value."""

        self.algoFft.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        zk = self.algoFft.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)
class CompensationImageDecoratorTest(unittest.TestCase):
    """Unit tests of the CompensationImageDecorator class."""

    def setUp(self):
        """Configure the instrument and locate the test images and the
        reference Zernike coefficients."""

        moduleRoot = getModulePath()
        testImagesDir = os.path.join(moduleRoot, "tests", "testData",
                                     "testImages")

        # Instrument "lsst" configured with 120 sensor samples
        self.inst = Instrument(
            os.path.join(moduleRoot, "configData", "cwfs", "instruData"))
        self.inst.config("lsst", 120)

        # The donut images are read from text files
        donutDir = os.path.join(testImagesDir, "LSST_NE_SN25")
        self.imgFilePathIntra = os.path.join(donutDir, "z11_0.25_intra.txt")
        self.imgFilePathExtra = os.path.join(donutDir, "z11_0.25_extra.txt")

        # Position of the donut on the focal plane in degree
        # ([1.185, 1.185] or [0, 0])
        self.fieldXY = [1.185, 1.185]

        # One of "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Reference Zernike coefficients
        self.zcCol = np.loadtxt(
            os.path.join(testImagesDir, "validation",
                         "LSST_NE_SN25_z11_0.25_exp.txt"))

    def testFunc(self):
        """Walk one intra-focal image through the setters, the off-axis
        correction, the mask creation, and the co-centering."""

        donut = CompensationImageDecorator()

        # Setting the image stores the array and the field position
        donut.setImg(self.fieldXY, imageFile=self.imgFilePathIntra,
                     atype="intra")
        fieldX, fieldY = self.fieldXY
        self.assertEqual(donut.image.shape, (120, 120))
        self.assertEqual(donut.fieldX, fieldX)
        self.assertEqual(donut.fieldY, fieldY)

        # After the update image0 equals the current image
        donut.updateImage0()
        self.assertEqual(np.sum(np.abs(donut.image0 - donut.image)), 0)

        # Off-axis correction read from the instrument directory
        donut.getOffAxisCorr(
            os.path.join(self.inst.instDir, self.inst.instName), 10)
        self.assertEqual(donut.offAxisOffset, 0.001)
        self.assertEqual(donut.offAxis_coeff.shape, (4, 66))
        self.assertAlmostEqual(donut.offAxis_coeff[0, 0], -2.6362089*1e-3)

        # Mask list of the paraxial model
        paraxialAns = np.array([[0, 0, 1, 1], [0, 0, 0.61, 0]])
        paraxialList = donut.makeMaskList(self.inst, "paraxial")
        self.assertEqual(np.sum(np.abs(paraxialList - paraxialAns)), 0)

        # Mask list of the off-axis model
        offAxisList = donut.makeMaskList(self.inst, "offAxis")
        offAxisAns = np.array([
            [0, 0, 1, 1],
            [0, 0, 0.61, 0],
            [-0.21240585, -0.21240585, 1.2300922, 1],
            [-0.08784336, -0.08784336, 0.55802573, 0],
        ])
        self.assertAlmostEqual(np.sum(np.abs(offAxisList - offAxisAns)), 0)

        # Mask creation with a boundary thickness of 8 pixel and a mask
        # scaling factor of 1
        donut.makeMask(self.inst, "offAxis", 8, 1)
        for mask in (donut.pMask, donut.cMask):
            self.assertEqual(mask.shape, donut.image.shape)
        self.assertEqual(np.sum(np.abs(donut.cMask - donut.pMask)), 3001)

        # Co-centering moves the donut center to pixel (63, 63)
        donut.imageCoCenter(self.inst)
        centerX, centerY = donut.getCenterAndR_ef()[:2]
        self.assertEqual(int(centerX), 63)
        self.assertEqual(int(centerY), 63)

    def testFuncCompensation(self):
        """Compensate both donuts with the reference coefficients and check
        that they agree on their common region."""

        # Stand-in for the algorithm that only carries the parameters
        fakeAlgo = tempAlgo()
        fakeAlgo.parameter = {"numTerms": 22, "offAxisPolyOrder": 10,
                              "zobsR": 0.61}

        boundaryT = 8
        offAxisCorrOrder = 10

        # The reference values fill the terms from index 3 on, scaled
        # by 1e-9
        zernikes = np.zeros(22)
        zernikes[3:] = self.zcCol*1e-9
        instDir = os.path.join(self.inst.instDir, self.inst.instName)

        donutIntra = CompensationImageDecorator()
        donutExtra = CompensationImageDecorator()
        donutIntra.setImg(self.fieldXY, imageFile=self.imgFilePathIntra,
                          atype="intra")
        donutExtra.setImg(self.fieldXY, imageFile=self.imgFilePathExtra,
                          atype="extra")

        donuts = (donutIntra, donutExtra)
        for donut in donuts:
            donut.makeMask(self.inst, self.opticalModel, boundaryT, 1)
            donut.getOffAxisCorr(instDir, offAxisCorrOrder)
            donut.imageCoCenter(self.inst)
            donut.compensate(self.inst, fakeAlgo, zernikes,
                             self.opticalModel)

        # Common region: keep only the pixels set in both binary images
        binaryIntra = donutIntra.getCenterAndR_ef()[3]
        binaryExtra = donutExtra.getCenterAndR_ef()[3]
        common = binaryIntra + binaryExtra
        common[common < 2] = 0
        common = common / 2

        # The summed difference inside the common region has to stay small
        difference = np.abs(donutIntra.getImg() - donutExtra.getImg())
        self.assertLess(np.sum(difference * common), 500)
Ejemplo n.º 8
0
class TestCompensableImage(unittest.TestCase):
    """Test the CompensableImage class."""

    def setUp(self):
        """Configure the instrument, the test image paths and a blank image."""

        # Get the path of module
        modulePath = getModulePath()

        # Define the instrument folder
        instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        # Donut image dimension handed to the instrument configuration
        dimOfDonutOnSensor = 120

        self.inst = Instrument(instDir)
        self.inst.config(CamType.LsstCam,
                         dimOfDonutOnSensor,
                         announcedDefocalDisInMm=1.0)

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # It is noted that the image file input is based on the txt file
        imageFolderPath = os.path.join(modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"
        self.imgFilePathIntra = os.path.join(imageFolderPath, intra_image_name)
        self.imgFilePathExtra = os.path.join(imageFolderPath, extra_image_name)

        # This is the position of donut on the focal plane in degree
        self.fieldXY = (1.185, 1.185)

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Get the true Zk
        zcAnsFilePath = os.path.join(
            modulePath,
            "tests",
            "testData",
            "testImages",
            "validation",
            "simulation",
            "LSST_NE_SN25_z11_0.25_exp.txt",
        )
        self.zcCol = np.loadtxt(zcAnsFilePath)

        # Image under test; no image data has been set on it yet
        self.wfsImg = CompensableImage()

    def testGetDefocalType(self):
        """A fresh image is expected to report the intra-focal type."""

        defocalType = self.wfsImg.getDefocalType()
        self.assertEqual(defocalType, DefocalType.Intra)

    def testGetImgObj(self):
        """The image object getter is expected to return an Image."""

        imgObj = self.wfsImg.getImgObj()
        self.assertIsInstance(imgObj, Image)

    def testGetImg(self):
        """A fresh image is expected to hold an empty numpy array."""

        img = self.wfsImg.getImg()
        self.assertIsInstance(img, np.ndarray)
        self.assertEqual(len(img), 0)

    def testGetImgSizeInPix(self):
        """A fresh image is expected to report a size of zero."""

        imgSizeInPix = self.wfsImg.getImgSizeInPix()
        self.assertEqual(imgSizeInPix, 0)

    def testGetOffAxisCoeff(self):
        """A fresh image is expected to have no off-axis coefficients."""

        offAxisCoeff, offAxisOffset = self.wfsImg.getOffAxisCoeff()
        self.assertIsInstance(offAxisCoeff, np.ndarray)
        self.assertEqual(len(offAxisCoeff), 0)
        self.assertEqual(offAxisOffset, 0.0)

    def testGetImgInit(self):
        """A fresh image is expected to have no initial image."""

        imgInit = self.wfsImg.getImgInit()
        self.assertIsNone(imgInit)

    def testIsCaustic(self):
        """A fresh image is expected not to be flagged as caustic."""

        self.assertFalse(self.wfsImg.isCaustic())

    def testGetPaddedMask(self):
        """A fresh image is expected to have an empty integer padded mask."""

        pMask = self.wfsImg.getPaddedMask()
        self.assertEqual(len(pMask), 0)
        self.assertEqual(pMask.dtype, int)

    def testGetNonPaddedMask(self):
        """A fresh image is expected to have an empty integer non-padded mask."""

        cMask = self.wfsImg.getNonPaddedMask()
        self.assertEqual(len(cMask), 0)
        self.assertEqual(cMask.dtype, int)

    def testGetFieldXY(self):
        """A fresh image is expected to report the field position (0, 0)."""

        fieldX, fieldY = self.wfsImg.getFieldXY()
        self.assertEqual(fieldX, 0)
        self.assertEqual(fieldY, 0)

    def testSetImg(self):
        """Loading the intra-focal test file should give a 120 x 120 image."""

        self._setIntraImg()
        self.assertEqual(self.wfsImg.getImg().shape, (120, 120))

    def _setIntraImg(self):
        """Load the intra-focal test image file into the image under test."""

        self.wfsImg.setImg(self.fieldXY,
                           DefocalType.Intra,
                           imageFile=self.imgFilePathIntra)

    def testUpdateImage(self):
        """updateImage() should replace the held image with the new array."""

        self._setIntraImg()

        newImg = np.random.rand(5, 5)
        self.wfsImg.updateImage(newImg)

        self.assertTrue(np.all(self.wfsImg.getImg() == newImg))

    def testUpdateImgInit(self):
        """updateImgInit() should make the initial image equal the image."""

        self._setIntraImg()

        self.wfsImg.updateImgInit()

        delta = np.sum(np.abs(self.wfsImg.getImgInit() - self.wfsImg.getImg()))
        self.assertEqual(delta, 0)

    def testImageCoCenter(self):
        """After co-centering, the donut center should be at pixel (63, 63)."""

        self._setIntraImg()

        self.wfsImg.imageCoCenter(self.inst)

        xc, yc = self.wfsImg.getImgObj().getCenterAndR()[0:2]
        self.assertEqual(int(xc), 63)
        self.assertEqual(int(yc), 63)

    def testCompensate(self):
        """Compensated intra-/extra-focal images should nearly agree."""

        # Generate a fake algorithm class
        algo = TempAlgo()

        # Test the function of image compensation
        boundaryT = 8
        offAxisCorrOrder = 10
        zcCol = np.zeros(22)
        zcCol[3:] = self.zcCol * 1e-9

        wfsImgIntra = CompensableImage()
        wfsImgExtra = CompensableImage()
        wfsImgIntra.setImg(
            self.fieldXY,
            DefocalType.Intra,
            imageFile=self.imgFilePathIntra,
        )
        wfsImgExtra.setImg(self.fieldXY,
                           DefocalType.Extra,
                           imageFile=self.imgFilePathExtra)

        for wfsImg in [wfsImgIntra, wfsImgExtra]:
            wfsImg.makeMask(self.inst, self.opticalModel, boundaryT, 1)
            wfsImg.setOffAxisCorr(self.inst, offAxisCorrOrder)
            wfsImg.imageCoCenter(self.inst)
            wfsImg.compensate(self.inst, algo, zcCol, self.opticalModel)

        # Get the common region
        intraImg = wfsImgIntra.getImg()
        extraImg = wfsImgExtra.getImg()

        centroid = CentroidRandomWalk()
        binaryImgIntra = centroid.getImgBinary(intraImg)
        binaryImgExtra = centroid.getImgBinary(extraImg)

        # Keep only the pixels that are set in both binary images
        binaryImg = binaryImgIntra + binaryImgExtra
        binaryImg[binaryImg < 2] = 0
        binaryImg = binaryImg / 2

        # Calculate the difference
        res = np.sum(np.abs(intraImg - extraImg) * binaryImg)
        self.assertLess(res, 500)

    def testCenterOnProjection(self):
        """centerOnProjection() should undo a known shift of a Gaussian."""

        template = self._prepareGaussian2D(100, 1)

        # Shift the template by (dx, dy) pixels. np.roll returns a new array,
        # so the template itself is left untouched.
        dx = 2
        dy = 8
        img = np.roll(np.roll(template, dx, axis=1), dy, axis=0)

        # The shifted image must differ noticeably from the template
        self.assertGreater(np.sum(np.abs(img - template)), 29)

        imgRecenter = self.wfsImg.centerOnProjection(img, template, window=20)
        self.assertLess(np.sum(np.abs(imgRecenter - template)), 1e-7)

    def _prepareGaussian2D(self, imgSize, sigma):
        """Return a normalized 2D Gaussian sampled on [-10, 10] x [-10, 10].

        Parameters
        ----------
        imgSize : int
            Number of samples along each axis.
        sigma : float
            Standard deviation of the Gaussian, shared by both axes.

        Returns
        -------
        numpy.ndarray
            Array of shape (imgSize, imgSize) with the Gaussian values.
        """

        x = np.linspace(-10, 10, imgSize)
        y = np.linspace(-10, 10, imgSize)

        xx, yy = np.meshgrid(x, y)

        return (1 / (2 * np.pi * sigma**2) *
                np.exp(-(xx**2 / (2 * sigma**2) + yy**2 / (2 * sigma**2))))

    def testSetOffAxisCorr(self):
        """setOffAxisCorr() should load the expected off-axis coefficients."""

        self._setIntraImg()

        offAxisCorrOrder = 10
        self.wfsImg.setOffAxisCorr(self.inst, offAxisCorrOrder)

        offAxisCoeff, offAxisOffset = self.wfsImg.getOffAxisCoeff()
        self.assertEqual(offAxisCoeff.shape, (4, 66))
        self.assertAlmostEqual(offAxisCoeff[0, 0], -2.6362089 * 1e-3)
        self.assertEqual(offAxisOffset, 0.001)

    def testMakeMaskListOfParaxial(self):
        """The paraxial model should give the two-row reference mask list."""

        self._setIntraImg()

        model = "paraxial"
        masklist = self.wfsImg.makeMaskList(self.inst, model)

        masklistAns = np.array([[0, 0, 1, 1], [0, 0, 0.61, 0]])
        self.assertEqual(np.sum(np.abs(masklist - masklistAns)), 0)

    def testMakeMaskListOfOffAxis(self):
        """The off-axis model should give the four-row reference mask list."""

        self._setIntraImg()

        model = "offAxis"
        masklist = self.wfsImg.makeMaskList(self.inst, model)

        masklistAns = np.array([
            [0, 0, 1, 1],
            [0, 0, 0.61, 0],
            [-0.21240585, -0.21240585, 1.2300922, 1],
            [-0.08784336, -0.08784336, 0.55802573, 0],
        ])
        self.assertAlmostEqual(np.sum(np.abs(masklist - masklistAns)), 0)

    def testMakeMask(self):
        """makeMask() should build both masks with the image's shape."""

        self._setIntraImg()

        boundaryT = 8
        maskScalingFactorLocal = 1
        model = "offAxis"
        self.wfsImg.makeMask(self.inst, model, boundaryT,
                             maskScalingFactorLocal)

        image = self.wfsImg.getImg()
        pMask = self.wfsImg.getPaddedMask()
        cMask = self.wfsImg.getNonPaddedMask()
        self.assertEqual(pMask.shape, image.shape)
        self.assertEqual(cMask.shape, image.shape)
        self.assertEqual(np.sum(np.abs(cMask - pMask)), 3001)
Ejemplo n.º 9
0
class TestInstrument(unittest.TestCase):
    """Test the Instrument class."""

    def setUp(self):
        """Configure an LsstCam instrument with a 1.5 mm defocal distance."""

        self.instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        self.inst = Instrument(self.instDir)
        self.dimOfDonutOnSensor = 120

        self.inst.config(CamType.LsstCam, self.dimOfDonutOnSensor,
                         announcedDefocalDisInMm=1.5)

    def testConfigWithUnsupportedCamType(self):
        """Configuring with CamType.LsstFamCam is expected to raise."""

        with self.assertRaises(ValueError):
            self.inst.config(CamType.LsstFamCam, 120)

    def testGetInstFileDir(self):
        """The instrument file directory should be the "lsst" subfolder."""

        instFileDir = self.inst.getInstFileDir()

        ansInstFileDir = os.path.join(self.instDir, "lsst")
        self.assertEqual(instFileDir, ansInstFileDir)

    def testGetAnnDefocalDisInMm(self):
        """The announced defocal distance should be the configured 1.5."""

        annDefocalDisInMm = self.inst.getAnnDefocalDisInMm()
        self.assertEqual(annDefocalDisInMm, 1.5)

    def testSetAnnDefocalDisInMm(self):
        """A newly set announced defocal distance should be read back."""

        annDefocalDisInMm = 2.0
        self.inst.setAnnDefocalDisInMm(annDefocalDisInMm)

        self.assertEqual(self.inst.getAnnDefocalDisInMm(), annDefocalDisInMm)

    def testGetInstFilePath(self):
        """The instrument file should exist and be named instParam.yaml."""

        instFilePath = self.inst.getInstFilePath()
        self.assertTrue(os.path.exists(instFilePath))
        self.assertEqual(os.path.basename(instFilePath), "instParam.yaml")

    def testGetMaskOffAxisCorr(self):
        """The mask off-axis correction should match the reference values."""

        maskOffAxisCorr = self.inst.getMaskOffAxisCorr()
        self.assertEqual(maskOffAxisCorr.shape, (9, 5))
        self.assertEqual(maskOffAxisCorr[0, 0], 1.07)
        self.assertEqual(maskOffAxisCorr[2, 3], -0.090100858)

    def testGetDimOfDonutOnSensor(self):
        """The donut dimension should be the configured value."""

        dimOfDonutOnSensor = self.inst.getDimOfDonutOnSensor()
        self.assertEqual(dimOfDonutOnSensor, self.dimOfDonutOnSensor)

    def testGetObscuration(self):
        """The obscuration should be 0.61."""

        obscuration = self.inst.getObscuration()
        self.assertEqual(obscuration, 0.61)

    def testGetFocalLength(self):
        """The focal length should be 10.312."""

        focalLength = self.inst.getFocalLength()
        self.assertEqual(focalLength, 10.312)

    def testGetApertureDiameter(self):
        """The aperture diameter should be 8.36."""

        apertureDiameter = self.inst.getApertureDiameter()
        self.assertEqual(apertureDiameter, 8.36)

    def testGetDefocalDisOffset(self):
        """The defocal distance offset should correspond to 1.5 mm."""

        defocalDisInM = self.inst.getDefocalDisOffset()

        # The answer is 1.5 mm. The value is rescaled before the comparison,
        # so compare with a tolerance instead of exact float equality.
        self.assertAlmostEqual(defocalDisInM * 1e3, 1.5)

    def testGetCamPixelSize(self):
        """The camera pixel size should correspond to 10 um."""

        camPixelSizeInM = self.inst.getCamPixelSize()

        # The answer is 10 um. The value is rescaled before the comparison,
        # so compare with a tolerance instead of exact float equality.
        self.assertAlmostEqual(camPixelSizeInM * 1e6, 10)

    def testGetMarginalFocalLength(self):
        """The marginal focal length should be about 9.4268."""

        marginalFL = self.inst.getMarginalFocalLength()
        self.assertAlmostEqual(marginalFL, 9.4268, places=4)

    def testGetSensorFactor(self):
        """The sensor factor should be about 0.98679."""

        sensorFactor = self.inst.getSensorFactor()
        self.assertAlmostEqual(sensorFactor, 0.98679, places=5)

    def testGetSensorCoor(self):
        """The sensor coordinate grids should match the reference samples."""

        gridShape = (self.dimOfDonutOnSensor, self.dimOfDonutOnSensor)

        xSensor, ySensor = self.inst.getSensorCoor()
        self.assertEqual(xSensor.shape, gridShape)
        self.assertAlmostEqual(xSensor[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(xSensor[0, 1], -0.96212, places=5)

        self.assertEqual(ySensor.shape, gridShape)
        self.assertAlmostEqual(ySensor[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(ySensor[1, 0], -0.96212, places=5)

    def testGetSensorCoorAnnular(self):
        """The annular grids should be NaN at the corner and at the center."""

        gridShape = (self.dimOfDonutOnSensor, self.dimOfDonutOnSensor)

        xoSensor, yoSensor = self.inst.getSensorCoorAnnular()
        self.assertEqual(xoSensor.shape, gridShape)
        self.assertTrue(np.isnan(xoSensor[0, 0]))
        self.assertTrue(np.isnan(xoSensor[60, 60]))

        self.assertEqual(yoSensor.shape, gridShape)
        self.assertTrue(np.isnan(yoSensor[0, 0]))
        self.assertTrue(np.isnan(yoSensor[60, 60]))
Ejemplo n.º 10
0
    def calcWfErr(
        self,
        centroidFindType,
        fieldXY,
        camType,
        algoName,
        announcedDefocalDisInMm,
        opticalModel,
        imageIntra=None,
        imageExtra=None,
        imageFileIntra=None,
        imageFileExtra=None,
    ):
        """Estimate the wavefront error from a pair of defocal donut images.

        Each image is supplied through its array argument and/or its file
        path argument; both are forwarded unchanged to the image object.

        Parameters
        ----------
        centroidFindType : enum 'CentroidFindType'
            Algorithm to find the centroid of donut.
        fieldXY : tuple or list
            Position of donut on the focal plane in degree (field x, field y).
        camType : enum 'CamType'
            Camera type.
        algoName : str
            Algorithm configuration file to solve the Poisson's equation in the
            transport of intensity equation (TIE). It can be "fft" or "exp"
            here.
        announcedDefocalDisInMm : float
            Announced defocal distance in mm. It is noted that the defocal
            distance offset used in calculation might be different from this
            value.
        opticalModel : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        imageIntra : numpy.ndarray, optional
            Array of intra-focal image. (the default is None.)
        imageExtra : numpy.ndarray, optional
            Array of extra-focal image. (the default is None.)
        imageFileIntra : str, optional
            Path of intra-focal image file. (the default is None.)
        imageFileExtra : str, optional
            Path of extra-focal image file. (the default is None.)

        Returns
        -------
        numpy.ndarray
            Zernike polynomials of z4-zn in nm.
        """

        # Wrap the intra- and extra-focal inputs as compensable images
        donutIntra = CompensableImage(centroidFindType=centroidFindType)
        donutExtra = CompensableImage(centroidFindType=centroidFindType)

        donutIntra.setImg(
            fieldXY,
            DefocalType.Intra,
            image=imageIntra,
            imageFile=imageFileIntra,
        )
        donutExtra.setImg(
            fieldXY,
            DefocalType.Extra,
            image=imageExtra,
            imageFile=imageFileExtra,
        )

        # Configure the instrument with the size of the intra-focal image
        instrument = Instrument(os.path.join(getConfigDir(), "cwfs", "instData"))
        donutSizeInPix = donutIntra.getImgSizeInPix()
        instrument.config(
            camType,
            donutSizeInPix,
            announcedDefocalDisInMm=announcedDefocalDisInMm,
        )

        # Configure the requested algorithm against that instrument
        wfAlgo = Algorithm(os.path.join(getConfigDir(), "cwfs", "algo"))
        wfAlgo.config(algoName, instrument, debugLevel=0)

        # Run the algorithm on the image pair
        wfAlgo.runIt(donutIntra, donutExtra, opticalModel, tol=1e-3)

        # Hand back the Zernikes Zn (n>=4)
        zk4UpInNm = wfAlgo.getZer4UpInNm()
        return zk4UpInNm
Ejemplo n.º 11
0
class WfEstimator(object):
    """Estimate the wavefront error from intra-/extra-focal donut images."""

    def __init__(self, instruFolderPath, algoFolderPath):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instruFolderPath : str
            Path to instrument directory.
        algoFolderPath : str
            Path to algorithm directory.
        """

        self.algo = Algorithm(algoFolderPath)
        self.inst = Instrument(instruFolderPath)
        self.ImgIntra = CompensationImageDecorator()
        self.ImgExtra = CompensationImageDecorator()
        self.opticalModel = ""

        self.sizeInPix = 0

    def getAlgo(self):
        """Get the algorithm object.

        Returns
        -------
        Algorithm
            Algorithm object.
        """

        return self.algo

    def getInst(self):
        """Get the instrument object.

        Returns
        -------
        Instrument
            Instrument object.
        """

        return self.inst

    def getIntraImg(self):
        """Get the intra-focal donut image.

        Returns
        -------
        CompensationImageDecorator
            Intra-focal donut image.
        """

        return self.ImgIntra

    def getExtraImg(self):
        """Get the extra-focal donut image.

        Returns
        -------
        CompensationImageDecorator
            Extra-focal donut image.
        """

        return self.ImgExtra

    def getOptModel(self):
        """Get the optical model.

        Returns
        -------
        str
            Optical model.
        """

        return self.opticalModel

    def getSizeInPix(self):
        """Get the donut image size in pixel defined by the config() function.

        Returns
        -------
        int
            Donut image size in pixel
        """

        return self.sizeInPix

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings.
        """

        self.algo.reset()

    def config(self,
               solver="exp",
               instName="lsst",
               opticalModel="offAxis",
               defocalDisInMm=None,
               sizeInPix=120,
               debugLevel=0):
        """Configure the TIE solver.

        All inputs are validated before anything is changed, so a rejected
        call leaves the estimator as it was.

        Parameters
        ----------
        solver : str, optional
            Algorithm to solve the Poisson's equation in the transport of
            intensity equation (TIE). It can be "fft" or "exp" here. (the
            default is "exp".)
        instName : str, optional
            Instrument name. It is "lsst" in the baseline. (the default is
            "lsst".)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis". (the
            default is "offAxis".)
        defocalDisInMm : float, optional
            Defocal distance in mm. When given, ten times this value is
            truncated to an integer and appended to instName as a zero-padded
            two-digit suffix, e.g. 0.5 gives "lsst05". (the default is None.)
        sizeInPix : int, optional
            Wavefront image pixel size. (the default is 120.)
        debugLevel : int, optional
            Show the information under the running. If the value is higher,
            the information shows more. It can be 0, 1, 2, or 3. (the default
            is 0.)

        Raises
        ------
        ValueError
            Wrong instrument name.
        ValueError
            Wrong Poisson solver name.
        ValueError
            Wrong optical model.
        """

        # Update the instrument name. The suffix must be zero-padded: the
        # supported names use two digits ("lsst05"), and a plain str() of the
        # integer would turn 0.5 mm into the unsupported "lsst5".
        if defocalDisInMm is not None:
            instName = "%s%02d" % (instName, int(10 * defocalDisInMm))

        if instName not in ("lsst", "lsst05", "lsst10", "lsst15", "lsst20",
                            "lsst25", "comcam10", "comcam15", "comcam20"):
            raise ValueError("Instrument can not be '%s'." % instName)

        if solver not in ("exp", "fft"):
            raise ValueError("Poisson solver can not be '%s'." % solver)

        if opticalModel not in ("paraxial", "onAxis", "offAxis"):
            raise ValueError("Optical model can not be '%s'." % opticalModel)

        # Set the available wavefront image size (n x n)
        self.sizeInPix = int(sizeInPix)

        # Configure the instrument first because the algorithm needs it
        self.inst.config(instName, self.sizeInPix)
        self.algo.config(solver, self.inst, debugLevel=debugLevel)

        self.opticalModel = opticalModel

    def setImg(self, fieldXY, image=None, imageFile=None, defocalType=None):
        """Set the wavefront image.

        Parameters
        ----------
        fieldXY : tuple or list
            Position of donut on the focal plane in degree for intra- and
            extra-focal images.
        image : numpy.ndarray, optional
            Array of image. (the default is None.)
        imageFile : str, optional
            Path of image file. (the default is None.)
        defocalType : str, optional
            Type of image. It should be "intra" or "extra". (the default is
            None, which is rejected.)

        Raises
        ------
        ValueError
            Wrong defocal type.
        """

        # Check the defocal type
        if defocalType not in (self.ImgIntra.INTRA, self.ImgIntra.EXTRA):
            raise ValueError("Defocal type can not be '%s'." % defocalType)

        # Read the image and assign the type
        if defocalType == self.ImgIntra.INTRA:
            self.ImgIntra.setImg(fieldXY,
                                 image=image,
                                 imageFile=imageFile,
                                 atype=defocalType)
        elif defocalType == self.ImgIntra.EXTRA:
            self.ImgExtra.setImg(fieldXY,
                                 image=image,
                                 imageFile=imageFile,
                                 atype=defocalType)

    def calWfsErr(self, tol=1e-3, showZer=False, showPlot=False):
        """Calculate the wavefront error.

        Parameters
        ----------
        tol : float, optional
            Tolerance of difference of coefficients of Zk polynomials compared
            with the previous iteration. (the default is 1e-3.)
        showZer : bool, optional
            Decide to show the annular Zernike polynomials or not. (the
            default is False.)
        showPlot : bool, optional
            Decide to show the plot or not. (the default is False.)

        Returns
        -------
        numpy.ndarray
            Coefficients of Zernike polynomials (z4 - z22).

        Raises
        ------
        RuntimeError
            Input image shape does not match the configured size.
        """

        # Check the image size
        for img in (self.ImgIntra, self.ImgExtra):
            d1, d2 = img.image.shape
            if (d1 != self.sizeInPix) or (d2 != self.sizeInPix):
                raise RuntimeError(
                    "Input image shape is (%d, %d), not required (%d, %d)" %
                    (d1, d2, self.sizeInPix, self.sizeInPix))

        # Calculate the wavefront error.
        # Run cwfs
        self.algo.runIt(self.inst,
                        self.ImgIntra,
                        self.ImgExtra,
                        self.opticalModel,
                        tol=tol)

        # Show the Zernikes Zn (n>=4)
        if showZer:
            self.algo.outZer4Up(showPlot=showPlot)

        return self.algo.zer4UpNm

    def outParam(self, filename=None):
        """Put the information of images, instrument, and algorithm on
        terminal or file.

        Parameters
        ----------
        filename : str, optional
            Name of output file. The information goes to the standard output
            when this is None. (the default is None.)
        """

        if filename is None:
            self._writeParam(sys.stdout)
        else:
            # The context manager closes the file even if a write fails
            with open(filename, "w") as fout:
                self._writeParam(fout)

    def _writeParam(self, fout):
        """Write the image, optical model and configuration information.

        Parameters
        ----------
        fout : file
            Writable text stream that receives the information.
        """

        # Write the information of image and optical model
        if self.ImgIntra.name is not None:
            fout.write("Intra image: \t %s\n" % self.ImgIntra.name)

        if self.ImgIntra.fieldX is not None:
            fout.write("Intra image field in deg =(%6.3f, %6.3f)\n" %
                       (self.ImgIntra.fieldX, self.ImgIntra.fieldY))

        if self.ImgExtra.name is not None:
            fout.write("Extra image: \t %s\n" % self.ImgExtra.name)

        if self.ImgExtra.fieldX is not None:
            fout.write("Extra image field in deg =(%6.3f, %6.3f)\n" %
                       (self.ImgExtra.fieldX, self.ImgExtra.fieldY))

        if self.opticalModel is not None:
            fout.write("Using optical model:\t %s\n" % self.opticalModel)

        # Read the instrument file
        if self.inst.filename is not None:
            self.__readConfigFile(fout, self.inst, "instrument")

        # Read the algorithm file
        if self.algo.filename is not None:
            self.__readConfigFile(fout, self.algo, "algorithm")

    def __readConfigFile(self, fout, config, configName):
        """Copy the effective lines of a configuration file to the output.

        Blank lines, lines starting with "#", and every line between a pair
        of lines starting with "###" are left out.

        Parameters
        ----------
        fout : file
            Writable text stream that receives the lines.
        config : Instrument or Algorithm
            Instance of configuration. Its 'filename' attribute is the file
            to read.
        configName : str
            Name of configuration used in the header line.
        """

        # Create a new line
        fout.write("\n")

        with open(config.filename) as fconfig:
            fout.write("---%s file: --- %s ----------\n" %
                       (configName, config.filename))

            # A "###" line switches the block-comment state on or off
            iscomment = False
            for line in fconfig:
                line = line.strip()
                if line.startswith("###"):
                    iscomment = not iscomment
                if line and not iscomment and not line.startswith("#"):
                    fout.write(line + "\n")
Ejemplo n.º 12
0
class WfEstimator(object):
    """Estimate the wavefront error from intra-/extra-focal donut images."""

    def __init__(self, instDir, algoDir):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instDir : str
            Path to instrument directory.
        algoDir : str
            Path to algorithm directory.
        """

        self.inst = Instrument(instDir)
        self.algo = Algorithm(algoDir)

        self.imgIntra = CompensableImage()
        self.imgExtra = CompensableImage()

        self.opticalModel = ""
        self.sizeInPix = 0

    def getAlgo(self):
        """Get the algorithm object.

        Returns
        -------
        Algorithm
            Algorithm object.
        """

        return self.algo

    def getInst(self):
        """Get the instrument object.

        Returns
        -------
        Instrument
            Instrument object.
        """

        return self.inst

    def getIntraImg(self):
        """Get the intra-focal donut image.

        Returns
        -------
        CompensableImage
            Intra-focal donut image.
        """

        return self.imgIntra

    def getExtraImg(self):
        """Get the extra-focal donut image.

        Returns
        -------
        CompensableImage
            Extra-focal donut image.
        """

        return self.imgExtra

    def getOptModel(self):
        """Get the optical model.

        Returns
        -------
        str
            Optical model.
        """

        return self.opticalModel

    def getSizeInPix(self):
        """Get the donut image size in pixel defined by the config() function.

        Returns
        -------
        int
            Donut image size in pixel
        """

        return self.sizeInPix

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings.
        """

        self.algo.reset()

    def config(self, solver="exp", camType=CamType.LsstCam,
               opticalModel="offAxis", defocalDisInMm=None, sizeInPix=120,
               debugLevel=0):
        """Configure the TIE solver.

        Parameters
        ----------
        solver : str, optional
            Algorithm to solve the Poisson's equation in the transport of
            intensity equation (TIE). It can be "fft" or "exp" here. (the
            default is "exp".)
        camType : enum 'CamType'
            Camera type. (the default is CamType.LsstCam.)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis". (the
            default is "offAxis".)
        defocalDisInMm : float, optional
            Defocal distance in mm. A value of None is replaced by 1.5. (the
            default is None.)
        sizeInPix : int, optional
            Wavefront image pixel size. (the default is 120.)
        debugLevel : int, optional
            Show the information under the running. If the value is higher,
            the information shows more. It can be 0, 1, 2, or 3. (the default
            is 0.)

        Raises
        ------
        ValueError
            Wrong Poisson solver name.
        ValueError
            Wrong optical model.
        """

        if solver not in ("exp", "fft"):
            raise ValueError("Poisson solver can not be '%s'." % solver)

        if opticalModel not in ("paraxial", "onAxis", "offAxis"):
            raise ValueError("Optical model can not be '%s'." % opticalModel)

        self.opticalModel = opticalModel

        # Fall back to 1.5 mm when no defocal distance is given
        if defocalDisInMm is None:
            defocalDisInMm = 1.5

        self.sizeInPix = int(sizeInPix)
        self.inst.config(camType, self.sizeInPix,
                         announcedDefocalDisInMm=defocalDisInMm)

        self.algo.config(solver, self.inst, debugLevel=debugLevel)

    def setImg(self, fieldXY, defocalType, image=None, imageFile=None):
        """Set the wavefront image.

        Parameters
        ----------
        fieldXY : tuple or list
            Position of donut on the focal plane in degree for intra- and
            extra-focal images.
        defocalType : enum 'DefocalType'
            Defocal type of image.
        image : numpy.ndarray, optional
            Array of image. (the default is None.)
        imageFile : str, optional
            Path of image file. (the default is None.)

        Raises
        ------
        ValueError
            Defocal type is neither intra- nor extra-focal.
        """

        if defocalType == DefocalType.Intra:
            img = self.imgIntra
        elif defocalType == DefocalType.Extra:
            img = self.imgExtra
        else:
            # Without this branch 'img' would be unbound below and the call
            # would fail with an unrelated UnboundLocalError.
            raise ValueError("Defocal type can not be '%s'." % defocalType)

        img.setImg(fieldXY, defocalType, image=image, imageFile=imageFile)

    def calWfsErr(self, tol=1e-3, showZer=False, showPlot=False):
        """Calculate the wavefront error.

        Parameters
        ----------
        tol : float, optional
            Tolerance of difference of coefficients of Zk polynomials compared
            with the previous iteration. (the default is 1e-3.)
        showZer : bool, optional
            Decide to show the annular Zernike polynomials or not. (the
            default is False.)
        showPlot : bool, optional
            Decide to show the plot or not. (the default is False.)

        Returns
        -------
        numpy.ndarray
            Coefficients of Zernike polynomials (z4 - z22).

        Raises
        ------
        RuntimeError
            Input image shape is wrong.
        """

        # Check the image size
        for img in (self.imgIntra, self.imgExtra):
            d1, d2 = img.getImg().shape
            if (d1 != self.sizeInPix) or (d2 != self.sizeInPix):
                raise RuntimeError(
                    "Input image shape is (%d, %d), not required (%d, %d)" %
                    (d1, d2, self.sizeInPix, self.sizeInPix))

        # Calculate the wavefront error.
        # Run cwfs
        self.algo.runIt(self.imgIntra, self.imgExtra, self.opticalModel,
                        tol=tol)

        # Show the Zernikes Zn (n>=4)
        if showZer:
            self.algo.outZer4Up(showPlot=showPlot)

        return self.algo.getZer4UpInNm()
Ejemplo n.º 13
0
class TestAlgorithm(unittest.TestCase):
    """Test the Algorithm class."""

    def setUp(self):
        """Load the intra-/extra-focal test images and configure one
        Algorithm per Poisson solver ("exp" and "fft")."""

        # Get the path of module
        self.modulePath = getModulePath()

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # Note that the image file input used here is based on txt files
        imageFolderPath = os.path.join(self.modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"

        # Define fieldXY: [1.185, 1.185] or [0, 0]
        # This is the position of donut on the focal plane in degree
        fieldXY = [1.185, 1.185]

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Image files Path
        intra_image_file = os.path.join(imageFolderPath, intra_image_name)
        extra_image_file = os.path.join(imageFolderPath, extra_image_name)

        # There is a difference between intra and extra images
        # I1: intra_focal images, I2: extra_focal Images
        self.I1 = CompensableImage()
        self.I2 = CompensableImage()

        self.I1.setImg(fieldXY, DefocalType.Intra, imageFile=intra_image_file)
        self.I2.setImg(fieldXY, DefocalType.Extra, imageFile=extra_image_file)

        # Set up the instrument
        cwfsConfigDir = os.path.join(getConfigDir(), "cwfs")

        instDir = os.path.join(cwfsConfigDir, "instData")
        self.inst = Instrument(instDir)

        self.inst.config(CamType.LsstCam, self.I1.getImgSizeInPix(),
                         announcedDefocalDisInMm=1.0)

        # Set up the algorithm
        algoDir = os.path.join(cwfsConfigDir, "algo")

        self.algoExp = Algorithm(algoDir)
        self.algoExp.config("exp", self.inst)

        self.algoFft = Algorithm(algoDir)
        self.algoFft.config("fft", self.inst)

    def testGetDebugLevel(self):
        """The debug level is 0 when config() is called without one."""

        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testSetDebugLevel(self):
        """The debug level can be set by config() and by setDebugLevel()."""

        self.algoExp.config("exp", self.inst, debugLevel=3)
        self.assertEqual(self.algoExp.getDebugLevel(), 3)

        self.algoExp.setDebugLevel(0)
        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testGetZer4UpInNm(self):
        """getZer4UpInNm() returns a numpy array."""

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertIsInstance(zer4UpNm, np.ndarray)

    def testGetPoissonSolverName(self):
        """Each algorithm reports the solver name it was configured with."""

        self.assertEqual(self.algoExp.getPoissonSolverName(), "exp")
        self.assertEqual(self.algoFft.getPoissonSolverName(), "fft")

    def testGetNumOfZernikes(self):
        """Both solvers are expected to use 22 Zernike terms."""

        self.assertEqual(self.algoExp.getNumOfZernikes(), 22)
        self.assertEqual(self.algoFft.getNumOfZernikes(), 22)

    def testGetZernikeTerms(self):
        """The Zernike terms are integers running from 1 to the number of
        Zernikes."""

        # assertTrue(x, y) treats y as the failure message and only checks
        # the truthiness of x, so the integer dtype is asserted explicitly.
        zTerms = self.algoExp.getZernikeTerms()
        self.assertTrue(np.issubdtype(zTerms.dtype, np.integer))
        self.assertEqual(len(zTerms), self.algoExp.getNumOfZernikes())
        self.assertEqual(zTerms[0], 1)
        self.assertEqual(zTerms[-1], self.algoExp.getNumOfZernikes())

        zTerms = self.algoFft.getZernikeTerms()
        self.assertTrue(np.issubdtype(zTerms.dtype, np.integer))
        self.assertEqual(len(zTerms), self.algoFft.getNumOfZernikes())

    def testGetObsOfZernikes(self):
        """The obscuration used for the Zernikes matches the instrument."""

        self.assertEqual(self.algoExp.getObsOfZernikes(),
                         self.inst.getObscuration())
        self.assertEqual(self.algoFft.getObsOfZernikes(),
                         self.inst.getObscuration())

    def testGetNumOfOuterItr(self):
        """Both solvers are expected to use 14 outer iterations."""

        self.assertEqual(self.algoExp.getNumOfOuterItr(), 14)
        self.assertEqual(self.algoFft.getNumOfOuterItr(), 14)

    def testGetNumOfInnerItr(self):
        """The "fft" solver is expected to use 6 inner iterations."""

        self.assertEqual(self.algoFft.getNumOfInnerItr(), 6)

    def testGetFeedbackGain(self):
        """Both solvers are expected to use a feedback gain of 0.6."""

        self.assertEqual(self.algoExp.getFeedbackGain(), 0.6)
        self.assertEqual(self.algoFft.getFeedbackGain(), 0.6)

    def testGetOffAxisPolyOrder(self):
        """Both solvers are expected to use an off-axis polynomial order of
        10."""

        self.assertEqual(self.algoExp.getOffAxisPolyOrder(), 10)
        self.assertEqual(self.algoFft.getOffAxisPolyOrder(), 10)

    def testGetCompensatorMode(self):
        """Both solvers are expected to use the "zer" compensator mode."""

        self.assertEqual(self.algoExp.getCompensatorMode(), "zer")
        self.assertEqual(self.algoFft.getCompensatorMode(), "zer")

    def testGetCompSequence(self):
        """The compensation sequence has one entry per outer iteration."""

        compSequence = self.algoExp.getCompSequence()
        self.assertIsInstance(compSequence, np.ndarray)
        self.assertEqual(compSequence.dtype, int)
        self.assertEqual(len(compSequence), self.algoExp.getNumOfOuterItr())
        self.assertEqual(compSequence[0], 4)
        self.assertEqual(compSequence[-1], 22)

        compSequence = self.algoFft.getCompSequence()
        self.assertEqual(len(compSequence), self.algoFft.getNumOfOuterItr())

    def testGetBoundaryThickness(self):
        """The boundary thickness is expected to be 8 for "exp" and 1 for
        "fft"."""

        self.assertEqual(self.algoExp.getBoundaryThickness(), 8)
        self.assertEqual(self.algoFft.getBoundaryThickness(), 1)

    def testGetFftDimension(self):
        """The "fft" solver is expected to use an FFT dimension of 128."""

        self.assertEqual(self.algoFft.getFftDimension(), 128)

    def testGetSignalClipSequence(self):
        """The signal clip sequence has one more entry than the number of
        outer iterations."""

        sumclipSequence = self.algoFft.getSignalClipSequence()
        self.assertIsInstance(sumclipSequence, np.ndarray)
        # Compare against the same ("fft") algorithm the sequence came from
        self.assertEqual(len(sumclipSequence),
                         self.algoFft.getNumOfOuterItr() + 1)
        self.assertEqual(sumclipSequence[0], 0.33)
        self.assertEqual(sumclipSequence[-1], 0.51)

    def testGetMaskScalingFactor(self):
        """Both solvers are expected to report a mask scaling factor of
        about 1.0939."""

        self.assertAlmostEqual(self.algoExp.getMaskScalingFactor(), 1.0939,
                               places=4)
        self.assertAlmostEqual(self.algoFft.getMaskScalingFactor(), 1.0939,
                               places=4)

    def testRunItOfExp(self):
        """Run the "exp" solver step by step, in full, and again after a
        reset."""

        # Test functions: itr0() and nextItr()
        # Calling itr0() again after nextItr() must reproduce the first result
        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp1 = self.algoExp.getZer4UpInNm()
        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=2)
        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp2 = self.algoExp.getZer4UpInNm()

        difference = np.sum(np.abs(tmp1 - tmp2))
        self.assertEqual(difference, 0)

        # Run it
        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Check the value
        Zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(Zk[7]), -192)

        # Reset and check the calculation again
        fieldXY = [self.I1.fieldX, self.I1.fieldY]
        self.I1.setImg(fieldXY, self.I1.getDefocalType(),
                       image=self.I1.getImgInit())
        self.I2.setImg(fieldXY, self.I2.getDefocalType(),
                       image=self.I2.getImgInit())
        self.algoExp.reset()
        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)
        Zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(Zk[7]), -192)

    def testRunItOfFft(self):
        """Run the "fft" solver in full and check one Zernike coefficient."""

        self.algoFft.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        zk = self.algoFft.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)
Ejemplo n.º 14
0
    def makeTemplate(
        self,
        sensorName,
        defocalType,
        imageSize,
        camType=CamType.LsstCam,
        opticalModel="offAxis",
        pixelScale=0.2,
    ):
        """Make the donut template image.

        The sensor entry is looked up in "focalplanelayout.txt", converted to
        a field position in degrees, and used to build a mask on an all-zero
        image of the requested size.

        Parameters
        ----------
        sensorName : str
            The camera detector for which we want to make a template. Should
            be in "Rxx_Sxx" format. The corner wavefront sensor halves
            ("Rxx_Sxx_C0" / "Rxx_Sxx_C1") get an extra half-sensor shift of
            their center.
        defocalType : enum 'DefocalType'
            The defocal state of the sensor.
        imageSize : int
            Size of template in pixels. The template will be a square.
        camType : enum 'CamType', optional
            Camera type. (The default is CamType.LsstCam)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
            (The default is "offAxis")
        pixelScale : float, optional
            The pixels to arcseconds conversion factor. (The default is 0.2)

        Returns
        -------
        numpy.ndarray [int]
            The donut template as a binary image.
        """

        configDir = getConfigDir()
        layout = readPhoSimSettingData(configDir,
                                       "focalplanelayout.txt",
                                       "fieldCenter")
        sensorInfo = layout[sensorName]

        # Presumably the layout columns are: center x, center y (microns),
        # pixel size (microns), number of pixels in x -- confirm against file
        pixelSizeInUm = float(sensorInfo[2])
        sizeXinPixel = int(sensorInfo[3])
        centerXInUm, centerYInUm = np.array(sensorInfo[:2], dtype=float)

        # Correction for wavefront sensors: move the center by half a sensor
        halfSensorInUm = sizeXinPixel / 2 * pixelSizeInUm
        if sensorName in ("R44_S00_C0", "R00_S22_C1"):
            # Shift center to +x direction
            centerXInUm += halfSensorInUm
        elif sensorName in ("R44_S00_C1", "R00_S22_C0"):
            # Shift center to -x direction
            centerXInUm -= halfSensorInUm
        elif sensorName in ("R04_S20_C1", "R40_S02_C0"):
            # Shift center to -y direction
            centerYInUm -= halfSensorInUm
        elif sensorName in ("R04_S20_C0", "R40_S02_C1"):
            # Shift center to +y direction
            centerYInUm += halfSensorInUm

        # Load Instrument parameters
        inst = Instrument(os.path.join(configDir, "cwfs", "instData"))
        inst.config(camType, imageSize)

        # Create image for mask
        template = CompensableImage()

        # microns -> pixels -> arcsec (times pixelScale) -> deg (over 3600)
        fieldXY = [
            float(centerInUm) / pixelSizeInUm * pixelScale / 3600
            for centerInUm in (centerXInUm, centerYInUm)
        ]

        # Put an empty image at the sensor center and build the mask with
        # no boundary thickness and a unit mask scaling factor
        boundaryT = 0
        maskScalingFactorLocal = 1
        blankImage = np.zeros((imageSize, imageSize))
        template.setImg(fieldXY, defocalType, image=blankImage)
        template.makeMask(inst, opticalModel, boundaryT,
                          maskScalingFactorLocal)

        return template.getNonPaddedMask()
Ejemplo n.º 15
0
class TestInstrument(unittest.TestCase):
    """Test the Instrument class."""

    def setUp(self):
        """Configure an LsstCam instrument with a 120-pixel donut and an
        announced defocal distance of 1.5 mm."""

        self.dimOfDonutOnSensor = 120
        self.instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        self.inst = Instrument(self.instDir)
        self.inst.config(CamType.LsstCam, self.dimOfDonutOnSensor,
                         announcedDefocalDisInMm=1.5)

    def testConfigWithUnsupportedCamType(self):
        """config() rejects an unknown camera type with ValueError."""

        with self.assertRaises(ValueError):
            self.inst.config("NoThisCamType", 120)

    def testGetInstFileDir(self):
        """The instrument file directory is the "lsst" subfolder."""

        self.assertEqual(self.inst.getInstFileDir(),
                         os.path.join(self.instDir, "lsst"))

    def testGetAnnDefocalDisInMm(self):
        """The announced defocal distance is the configured one."""

        self.assertEqual(self.inst.getAnnDefocalDisInMm(), 1.5)

    def testSetAnnDefocalDisInMm(self):
        """A newly set announced defocal distance is returned by the getter."""

        newDefocalDisInMm = 2.0
        self.inst.setAnnDefocalDisInMm(newDefocalDisInMm)

        self.assertEqual(self.inst.getAnnDefocalDisInMm(), newDefocalDisInMm)

    def testGetInstFilePath(self):
        """The instrument file exists and is named "instParam.yaml"."""

        filePath = self.inst.getInstFilePath()

        self.assertTrue(os.path.exists(filePath))
        self.assertEqual(os.path.basename(filePath), "instParam.yaml")

    def testGetMaskOffAxisCorr(self):
        """Check the shape and two entries of the off-axis mask correction."""

        corr = self.inst.getMaskOffAxisCorr()

        self.assertEqual(corr.shape, (9, 5))
        self.assertEqual(corr[0, 0], 1.07)
        self.assertEqual(corr[2, 3], -0.090100858)

    def testGetDimOfDonutOnSensor(self):
        """The donut dimension is the configured one."""

        self.assertEqual(self.inst.getDimOfDonutOnSensor(),
                         self.dimOfDonutOnSensor)

    def testGetObscuration(self):
        """Check the obscuration value."""

        self.assertEqual(self.inst.getObscuration(), 0.61)

    def testGetFocalLength(self):
        """Check the focal length value."""

        self.assertEqual(self.inst.getFocalLength(), 10.312)

    def testGetApertureDiameter(self):
        """Check the aperture diameter value."""

        self.assertEqual(self.inst.getApertureDiameter(), 8.36)

    def testGetDefocalDisOffset(self):
        """The defocal distance offset is 1.5 mm, returned in meters."""

        offsetInMm = self.inst.getDefocalDisOffset() * 1e3
        self.assertEqual(offsetInMm, 1.5)

    def testGetCamPixelSize(self):
        """The camera pixel size is 10 um, returned in meters."""

        pixelSizeInUm = self.inst.getCamPixelSize() * 1e6
        self.assertEqual(pixelSizeInUm, 10)

    def testGetMarginalFocalLength(self):
        """Check the marginal focal length to four decimal places."""

        self.assertAlmostEqual(self.inst.getMarginalFocalLength(), 9.4268,
                               places=4)

    def testGetSensorFactor(self):
        """Check the sensor factor to five decimal places."""

        self.assertAlmostEqual(self.inst.getSensorFactor(), 0.98679,
                               places=5)

    def testGetSensorCoor(self):
        """Check the shape and the first grid steps of the sensor
        coordinates."""

        gridShape = (self.dimOfDonutOnSensor, self.dimOfDonutOnSensor)
        xSensor, ySensor = self.inst.getSensorCoor()

        # x varies along the second index
        self.assertEqual(xSensor.shape, gridShape)
        self.assertAlmostEqual(xSensor[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(xSensor[0, 1], -0.96212, places=5)

        # y varies along the first index
        self.assertEqual(ySensor.shape, gridShape)
        self.assertAlmostEqual(ySensor[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(ySensor[1, 0], -0.96212, places=5)

    def testGetSensorCoorAnnular(self):
        """The annular sensor coordinates are NaN at the corner and at the
        center."""

        gridShape = (self.dimOfDonutOnSensor, self.dimOfDonutOnSensor)
        xoSensor, yoSensor = self.inst.getSensorCoorAnnular()

        self.assertEqual(xoSensor.shape, gridShape)
        self.assertTrue(np.isnan(xoSensor[0, 0]))
        self.assertTrue(np.isnan(xoSensor[60, 60]))

        self.assertEqual(yoSensor.shape, gridShape)
        self.assertTrue(np.isnan(yoSensor[0, 0]))
        self.assertTrue(np.isnan(yoSensor[60, 60]))

    def testCalcSizeOfDonutExpected(self):
        """Check the expected donut size to seven decimal places."""

        expectedSize = 121.60589604
        self.assertAlmostEqual(self.inst.calcSizeOfDonutExpected(),
                               expectedSize, places=7)

    def testDataAuxTel(self):
        """Check the parameters of an instrument configured as AuxTel."""

        auxTel = Instrument(self.instDir)
        auxTel.config(CamType.AuxTel, 160, announcedDefocalDisInMm=0.8)

        self.assertEqual(auxTel.getObscuration(), 0.3525)
        self.assertEqual(auxTel.getFocalLength(), 21.6)
        self.assertEqual(auxTel.getApertureDiameter(), 1.2)
        self.assertEqual(auxTel.getDefocalDisOffset(), 0.041 * 0.8)
        self.assertEqual(auxTel.getCamPixelSize(), 10.0e-6)
        self.assertAlmostEqual(auxTel.calcSizeOfDonutExpected(), 182.2222222,
                               places=7)

    def testDataAuxTelZWO(self):
        """Check the parameters of an instrument configured as AuxTelZWO."""

        auxTelZwo = Instrument(self.instDir)
        auxTelZwo.config(CamType.AuxTelZWO, 160, announcedDefocalDisInMm=0.5)

        self.assertEqual(auxTelZwo.getObscuration(), 0.3525)
        self.assertEqual(auxTelZwo.getFocalLength(), 21.6)
        self.assertEqual(auxTelZwo.getApertureDiameter(), 1.2)
        self.assertEqual(auxTelZwo.getDefocalDisOffset(), 0.0205)
        self.assertEqual(auxTelZwo.getCamPixelSize(), 15.2e-6)
        self.assertAlmostEqual(auxTelZwo.calcSizeOfDonutExpected(),
                               74.92690058, places=7)
Ejemplo n.º 16
0
def runWEP(instDir, algoFolderPath, useAlgorithm, imageFolderPath,
           intra_image_name, extra_image_name, fieldXY, opticalModel,
           showFig=False):
    """Calculate the coefficients of normal/ annular Zernike polynomials based
    on the provided instrument, algorithm, and optical model.

    The instrument is always configured as CamType.LsstCam with an announced
    defocal distance of 1.0 mm, and the algorithm is run with a tolerance of
    1e-3.

    Parameters
    ----------
    instDir : str
        Path to instrument folder.
    algoFolderPath : str
        Path to algorithm folder.
    useAlgorithm : str
        Algorithm to solve the Poisson's equation in the transport of intensity
        equation (TIE). It can be "fft" or "exp" here.
    imageFolderPath : str
        Path to image folder.
    intra_image_name : str
        File name of intra-focal image.
    extra_image_name : str
        File name of extra-focal image.
    fieldXY : tuple
        Position of donut on the focal plane in degree for intra- and
        extra-focal images.
    opticalModel : str
        Optical model. It can be "paraxial", "onAxis", or "offAxis".
    showFig : bool, optional
        Show the wavefront image and compensated image or not. (the default is
        False.)

    Returns
    -------
    numpy.ndarray
        Coefficients of Zernike polynomials (z4 - z22).
    """

    # Locate the two defocal image files
    intraPath = os.path.join(imageFolderPath, intra_image_name)
    extraPath = os.path.join(imageFolderPath, extra_image_name)

    # The intra- and extra-focal images are kept in separate objects
    intraImg = CompensableImage()
    extraImg = CompensableImage()
    intraImg.setImg(fieldXY, DefocalType.Intra, imageFile=intraPath)
    extraImg.setImg(fieldXY, DefocalType.Extra, imageFile=extraPath)

    # Set the instrument, sized from the intra-focal image
    inst = Instrument(instDir)
    inst.config(CamType.LsstCam, intraImg.getImgSizeInPix(),
                announcedDefocalDisInMm=1.0)

    # Define the algorithm to be used
    algo = Algorithm(algoFolderPath)
    algo.config(useAlgorithm, inst, debugLevel=0)

    # Plot the original wavefront images
    if showFig:
        for img, title in ((intraImg, "intra image"),
                           (extraImg, "extra image")):
            plotImage(img.image, title=title)

    # Run it and show the Zernikes Zn (n>=4)
    algo.runIt(intraImg, extraImg, opticalModel, tol=1e-3)
    algo.outZer4Up(showPlot=False)

    # Plot the images and the wavefront after the run
    if showFig:
        for img, title in ((intraImg, "Compensated intra image"),
                           (extraImg, "Compensated extra image")):
            plotImage(img.image, title=title)

        plotImage(algo.wcomp, title="Final wavefront")
        plotImage(algo.wcomp, title="Final wavefront with pupil mask applied",
                  mask=algo.pMask)

    # Return the Zernikes Zn (n>=4)
    return algo.getZer4UpInNm()
Ejemplo n.º 17
0
class TestAlgorithm(unittest.TestCase):
    """Test the Algorithm class."""

    def setUp(self):
        """Load the intra-/extra-focal test images and configure the "lsst"
        instrument from the module's configuration data."""

        # Get the path of module
        self.modulePath = getModulePath()

        # Folders of the instrument and algorithm configuration
        cwfsConfigDir = os.path.join(self.modulePath, "configData", "cwfs")
        self.algoFolderPath = os.path.join(cwfsConfigDir, "algo")

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Image data -- Don't know the final image format.
        # The test images used here are stored as txt files
        imageFolderPath = os.path.join(self.modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intraImageFile = os.path.join(imageFolderPath, "z11_0.25_intra.txt")
        extraImageFile = os.path.join(imageFolderPath, "z11_0.25_extra.txt")

        # Position of donut on the focal plane in degree:
        # [1.185, 1.185] or [0, 0]
        fieldXY = [1.185, 1.185]

        # I1: intra_focal image, I2: extra_focal image
        self.I1 = CompensationImageDecorator()
        self.I2 = CompensationImageDecorator()

        self.I1.setImg(fieldXY, imageFile=intraImageFile, atype="intra")
        self.I2.setImg(fieldXY, imageFile=extraImageFile, atype="extra")

        self.inst = Instrument(os.path.join(cwfsConfigDir, "instruData"))
        self.inst.config("lsst", self.I1.sizeinPix)

    def testExp(self):
        """Run the "exp" solver step by step, in full, and again after a
        reset."""

        # Configure with a non-zero debug level, then switch it off again
        algo = Algorithm(self.algoFolderPath)
        algo.config("exp", self.inst, debugLevel=3)
        algo.setDebugLevel(0)
        self.assertEqual(algo.debugLevel, 0)

        # Test functions: itr0() and nextItr()
        # Calling itr0() again after nextItr() must reproduce the first result
        algo.itr0(self.inst, self.I1, self.I2, self.opticalModel)
        firstZer = algo.zer4UpNm
        algo.nextItr(self.inst, self.I1, self.I2, self.opticalModel, nItr=2)
        algo.itr0(self.inst, self.I1, self.I2, self.opticalModel)
        secondZer = algo.zer4UpNm

        self.assertEqual(np.sum(np.abs(firstZer - secondZer)), 0)

        # Run it and check the value
        algo.runIt(self.inst, self.I1, self.I2, self.opticalModel, tol=1e-3)
        self.assertEqual(int(algo.zer4UpNm[7]), -192)

        # Reset and check the calculation again
        fieldXY = [self.I1.fieldX, self.I1.fieldY]
        self.I1.setImg(fieldXY, image=self.I1.image0, atype=self.I1.atype)
        self.I2.setImg(fieldXY, image=self.I2.image0, atype=self.I2.atype)
        algo.reset()
        algo.runIt(self.inst, self.I1, self.I2, self.opticalModel, tol=1e-3)
        self.assertEqual(int(algo.zer4UpNm[7]), -192)

    def testFFT(self):
        """Run the "fft" solver in full and check one Zernike coefficient."""

        algo = Algorithm(self.algoFolderPath)
        algo.config("fft", self.inst, debugLevel=0)

        # Run it and check the value
        algo.runIt(self.inst, self.I1, self.I2, self.opticalModel, tol=1e-3)
        self.assertEqual(int(algo.zer4UpNm[7]), -192)
Ejemplo n.º 18
0
 def testInstrument(self):
     """The configured sensor sample count is stored in the instrument
     parameters."""
     sensorSamples = 120
     instrument = Instrument(self.instruFolder)
     instrument.config(self.instruName, sensorSamples)
     self.assertEqual(instrument.parameter["sensorSamples"], sensorSamples)
Ejemplo n.º 19
0
    def makeTemplate(
        self,
        sensorName,
        defocalType,
        imageSize,
        camType=CamType.LsstCam,
        opticalModel="offAxis",
        pixelScale=0.2,
    ):
        """Make the donut template image.

        The center of the sensor is converted to a field position in degrees
        and used to build a mask on an all-zero image of the requested size.

        Parameters
        ----------
        sensorName : str
            The camera detector for which we want to make a template. Should
            be in "Rxx_Sxx" format. It is ignored for CamType.AuxTel, where
            the first detector of the Latiss camera is used instead.
        defocalType : enum 'DefocalType'
            The defocal state of the sensor.
        imageSize : int
            Size of template in pixels. The template will be a square.
        camType : enum 'CamType', optional
            Camera type. (The default is CamType.LsstCam)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
            Only "onAxis" is accepted for CamType.AuxTel. (The default is
            "offAxis")
        pixelScale : float, optional
            The pixels to arcseconds conversion factor. (The default is 0.2)

        Returns
        -------
        numpy.ndarray [int]
            The donut template as a binary image.

        Raises
        ------
        ValueError
            Camera type is not supported.
        ValueError
            Optical model is not "onAxis" when the camera type is AuxTel.
        """

        configDir = getConfigDir()

        # Load Instrument parameters
        inst = Instrument(os.path.join(configDir, "cwfs", "instData"))

        if camType == CamType.AuxTel:
            # AuxTel only works with onAxis sources
            if opticalModel != "onAxis":
                raise ValueError(
                    f"Optical Model {opticalModel} not supported with AuxTel. "
                    + "Must use 'onAxis'.")

            # The defocal distance has to be given explicitly for Latiss;
            # the other cameras can use the default of config()
            inst.config(camType, imageSize, getDefocalDisInMm("auxTel"))

            # The instrument reports the pixel size in meters
            pixelSizeInUm = inst.getCamPixelSize() * 1e6

            # There is only one detector in Latiss
            camera = obs_lsst.Latiss.getCamera()
            latissDetectorName = list(camera.getNameIter())[0]
            detector = camera.get(latissDetectorName)

            # Center of CCD in mm
            xp, yp = detector.getCenter(cameraGeom.FOCAL_PLANE)

            # Multiply by 1000 for mm --> microns conversion.
            # NOTE(review): x and y are swapped here exactly as in the
            # original code -- presumably an axis-convention transpose;
            # confirm.
            sensorXMicron = yp * 1000
            sensorYMicron = xp * 1000

        elif camType in (CamType.LsstCam, CamType.LsstFamCam, CamType.ComCam):
            inst.config(camType, imageSize)
            layout = readPhoSimSettingData(configDir,
                                           "focalplanelayout.txt",
                                           "fieldCenter")
            sensorInfo = layout[sensorName]

            pixelSizeInUm = float(sensorInfo[2])
            sensorXMicron, sensorYMicron = np.array(sensorInfo[:2],
                                                    dtype=float)

        else:
            raise ValueError("Camera type (%s) is not supported." % camType)

        # Create image for mask
        template = CompensableImage()

        # microns -> pixels -> arcsec (times pixelScale) -> deg (over 3600)
        fieldXY = [
            float(centerInUm) / pixelSizeInUm * pixelScale / 3600
            for centerInUm in (sensorXMicron, sensorYMicron)
        ]

        # Put an empty image at the sensor center and build the mask with
        # no boundary thickness and a unit mask scaling factor
        boundaryT = 0
        maskScalingFactorLocal = 1
        blankImage = np.zeros((imageSize, imageSize))
        template.setImg(fieldXY, defocalType, image=blankImage)
        template.makeMask(inst, opticalModel, boundaryT,
                          maskScalingFactorLocal)

        return template.getNonPaddedMask()