Пример #1
0
    def __init__(self, algoDir):
        """Set up an Algorithm instance with empty solver state.

        The algorithm solves the transport of intensity equation to get the
        normal/ annular Zernike polynomials.

        Parameters
        ----------
        algoDir : str
            Algorithm configuration directory.
        """

        # Configuration directory and a parameter reader created with no
        # arguments
        self.algoDir = algoDir
        self.algoParamFile = ParamReader()

        # Instrument placeholder created from an empty directory string
        self._inst = Instrument("")

        # Scalar state:
        # - debugLevel decides which calculation messages are shown
        #   (0 means no message is shown)
        # - caustic tells if the image has the problem from the
        #   over-compensation
        # - currentItr is the current number of outer-loop iteration
        self.debugLevel = 0
        self.caustic = False
        self.currentItr = 0

        # Zk coefficients recorded in each outer-loop iteration. The actual
        # total outer-loop iteration time is Num_of_outer_itr + 1.
        self.converge = np.array([])

        # Coefficients of normal/ annular Zernike polynomials after z4 in
        # unit of nm
        self.zer4UpNm = np.array([])

        # Wavefront: the converged one, followed by the one calculated in the
        # previous outer-loop iteration
        self.wcomp = np.array([])
        self.West = np.array([])

        # Zk coefficients: the converged ones, followed by the ones
        # calculated in the previous outer-loop iteration
        self.zcomp = np.array([])
        self.zc = np.array([])

        # Masks: padded mask for use at the offset planes (pMask) and
        # non-padded mask corresponding to aperture (cMask)
        self.pMask = None
        self.cMask = None

        # The same masks with the dimension changed for fft to use
        self.pMaskPad = None
        self.cMaskPad = None
Пример #2
0
    def __init__(self, instDir, algoDir):
        """Create the instrument, algorithm and image holders.

        Parameters
        ----------
        instDir : str
            Directory handed to the Instrument constructor.
        algoDir : str
            Directory handed to the Algorithm constructor.
        """

        # Plain defaults: no optical model chosen and a zero image size
        self.opticalModel = ""
        self.sizeInPix = 0

        # Collaborators, created in the order instrument -> algorithm
        self.inst = Instrument(instDir)
        self.algo = Algorithm(algoDir)

        # One image holder for each side of focus
        self.imgIntra = CompensableImage()
        self.imgExtra = CompensableImage()
Пример #3
0
    def setUp(self):
        """Create an Instrument configured for CamType.LsstCam.

        The donut dimension on the sensor is 120 and the announced defocal
        distance is 1.5 mm.
        """

        self.dimOfDonutOnSensor = 120
        self.instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        self.inst = Instrument(self.instDir)
        self.inst.config(
            CamType.LsstCam,
            self.dimOfDonutOnSensor,
            announcedDefocalDisInMm=1.5,
        )
Пример #4
0
    def __init__(self, algoDir):
        """Set up an Algorithm instance with empty solver state and caches.

        Parameters
        ----------
        algoDir : str
            Value stored as the algorithm directory (``self.algoDir``).
        """

        # Configuration directory and a parameter reader created with no
        # arguments
        self.algoDir = algoDir
        self.algoParamFile = ParamReader()

        # Instrument placeholder created from an empty directory string
        self._inst = Instrument("")

        # Scalar state:
        # - debugLevel decides which calculation messages are shown
        #   (0 means no message is shown)
        # - caustic tells if the image has the problem from the
        #   over-compensation
        # - currentItr is the current number of outer-loop iteration
        self.debugLevel = 0
        self.caustic = False
        self.currentItr = 0

        # Zk coefficients recorded in each outer-loop iteration. The actual
        # total outer-loop iteration time is Num_of_outer_itr + 1.
        self.converge = np.array([])

        # Coefficients of normal/ annular Zernike polynomials after z4 in
        # unit of nm
        self.zer4UpNm = np.array([])

        # Wavefront: the converged one, followed by the one calculated in the
        # previous outer-loop iteration
        self.wcomp = np.array([])
        self.West = np.array([])

        # Zk coefficients: the converged ones, followed by the ones
        # calculated in the previous outer-loop iteration
        self.zcomp = np.array([])
        self.zc = np.array([])

        # Masks: padded mask for use at the offset planes (mask_comp) and
        # non-padded mask corresponding to aperture (mask_pupil)
        self.mask_comp = None
        self.mask_pupil = None

        # The same masks with the dimension changed for fft to use
        self.mask_comp_pad = None
        self.mask_pupil_pad = None

        # Caches, all empty at start: annular Zernike evaluations and the
        # evaluations of their X and Y gradients
        self._zk = None
        self._dzkdx = None
        self._dzkdy = None
Пример #5
0
    def testDataAuxTelZWO(self):
        """Check the values reported after configuring CamType.AuxTelZWO."""

        auxTelZwo = Instrument(self.instDir)
        auxTelZwo.config(CamType.AuxTelZWO, 160, announcedDefocalDisInMm=0.5)

        # Getter name paired with the exact value it has to return
        exactValues = (
            ("getObscuration", 0.3525),
            ("getFocalLength", 21.6),
            ("getApertureDiameter", 1.2),
            ("getDefocalDisOffset", 0.0205),
            ("getCamPixelSize", 15.2e-6),
        )
        for getterName, expected in exactValues:
            self.assertEqual(getattr(auxTelZwo, getterName)(), expected)

        # The expected donut size is only compared to 7 decimal places
        donutSize = auxTelZwo.calcSizeOfDonutExpected()
        self.assertAlmostEqual(donutSize, 74.92690058, places=7)
Пример #6
0
    def setUp(self):
        """Load the LSST_NE_SN25 donut pair and configure the instrument.

        The intra- and extra-focal images are read from text files under
        ``tests/testData/testImages/LSST_NE_SN25`` and stored as ``self.I1``
        and ``self.I2``. The instrument is configured with the name "lsst"
        and the pixel size taken from the intra-focal image.
        """

        # Get the path of module
        self.modulePath = getModulePath()

        # Define the instrument folder
        instruFolder = os.path.join(self.modulePath, "configData", "cwfs",
                                    "instruData")

        # Define the instrument name
        instruName = "lsst"

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # It is noted that the image files read here are txt files
        imageFolderPath = os.path.join(self.modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"

        # Define fieldXY: [1.185, 1.185] or [0, 0]
        # This is the position of donut on the focal plane in degree
        fieldXY = [1.185, 1.185]

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Image files Path
        intra_image_file = os.path.join(imageFolderPath, intra_image_name)
        extra_image_file = os.path.join(imageFolderPath, extra_image_name)

        # There is the difference between intra and extra images
        # I1: intra_focal images, I2: extra_focal Images
        self.I1 = CompensationImageDecorator()
        self.I2 = CompensationImageDecorator()

        self.I1.setImg(fieldXY, imageFile=intra_image_file, atype="intra")
        self.I2.setImg(fieldXY, imageFile=extra_image_file, atype="extra")

        self.inst = Instrument(instruFolder)
        self.inst.config(instruName, self.I1.sizeinPix)
Пример #7
0
    def setUp(self):
        """Prepare the donut images, the instrument and both algorithms.

        ``self.I1``/``self.I2`` hold the intra-/extra-focal LSST_NE_SN25
        images, ``self.inst`` is configured for CamType.LsstCam, and
        ``self.algoExp``/``self.algoFft`` are configured with the "exp" and
        "fft" settings.
        """

        self.modulePath = getModulePath()

        # Optical model used by the tests: "paraxial", "onAxis" or "offAxis"
        self.opticalModel = "offAxis"

        # Position of donut on the focal plane in degree
        # ([1.185, 1.185] or [0, 0])
        donutFieldXY = [1.185, 1.185]

        # Test images are stored as txt files; the final image format is not
        # known yet
        imageDir = os.path.join(
            self.modulePath, "tests", "testData", "testImages", "LSST_NE_SN25"
        )
        intraPath = os.path.join(imageDir, "z11_0.25_intra.txt")
        extraPath = os.path.join(imageDir, "z11_0.25_extra.txt")

        # I1 is the intra-focal image and I2 is the extra-focal image
        self.I1 = CompensableImage()
        self.I2 = CompensableImage()
        self.I1.setImg(donutFieldXY, DefocalType.Intra, imageFile=intraPath)
        self.I2.setImg(donutFieldXY, DefocalType.Extra, imageFile=extraPath)

        # Instrument, sized from the intra-focal image
        cwfsDir = os.path.join(getConfigDir(), "cwfs")
        self.inst = Instrument(os.path.join(cwfsDir, "instData"))
        self.inst.config(
            CamType.LsstCam,
            self.I1.getImgSizeInPix(),
            announcedDefocalDisInMm=1.0,
        )

        # Algorithms: one instance per configuration name
        algoConfigDir = os.path.join(cwfsDir, "algo")

        self.algoExp = Algorithm(algoConfigDir)
        self.algoExp.config("exp", self.inst)

        self.algoFft = Algorithm(algoConfigDir)
        self.algoFft.config("fft", self.inst)
Пример #8
0
    def _runWep(self, imgIntraName, imgExtraName, offset, model):
        """Crop a donut pair from two image files and run the "exp" algorithm.

        Parameters
        ----------
        imgIntraName : str
            Intra-focal image file name, joined onto ``self.testImgDir``.
        imgExtraName : str
            Extra-focal image file name, joined onto ``self.testImgDir``.
        offset : int
            Half-width of the crop: rows and columns from ``center - offset``
            up to ``center + offset`` are kept.
        model : str
            Passed unchanged as the third argument of ``Algorithm.runIt``
            (presumably the optical model name -- confirm against
            Algorithm.runIt).

        Returns
        -------
        Whatever ``Algorithm.getZer4UpInNm()`` returns after the run.
        """

        finderType = CentroidFindType.Otsu

        def cropDonut(sourceImg):
            # Square cut-out of 2 * offset pixels around the center reported
            # by the image itself (the radius is not needed)
            centerX, centerY, _ = sourceImg.getCenterAndR()
            pixels = sourceImg.getImg()
            row, col = int(centerY), int(centerX)
            return pixels[row - offset:row + offset,
                          col - offset:col + offset]

        # Load the two input files
        rawIntra = Image(centroidFindType=finderType)
        rawExtra = Image(centroidFindType=finderType)

        intraFile = os.path.join(self.testImgDir, imgIntraName)
        extraFile = os.path.join(self.testImgDir, imgExtraName)

        rawIntra.setImg(imageFile=intraFile)
        rawExtra.setImg(imageFile=extraFile)

        # Cut the donut image from input files
        intraDonut = cropDonut(rawIntra)
        extraDonut = cropDonut(rawExtra)

        # Wrap the cut-outs as compensable images at the field center
        fieldCenter = (0, 0)
        compIntra = CompensableImage(centroidFindType=finderType)
        compIntra.setImg(fieldCenter, DefocalType.Intra, image=intraDonut)

        compExtra = CompensableImage(centroidFindType=finderType)
        compExtra.setImg(fieldCenter, DefocalType.Extra, image=extraDonut)

        # Instrument: AuxTel camera with 0.8 mm announced defocal distance
        cwfsDir = os.path.join(getConfigDir(), "cwfs")
        auxTelInst = Instrument(os.path.join(cwfsDir, "instData"))
        auxTelInst.config(
            CamType.AuxTel,
            compIntra.getImgSizeInPix(),
            announcedDefocalDisInMm=0.8,
        )

        # Algorithm: "exp" configuration, run on the cropped donut pair
        auxTelAlgo = Algorithm(os.path.join(getConfigDir(), "cwfs", "algo"))
        auxTelAlgo.config("exp", auxTelInst)
        auxTelAlgo.runIt(compIntra, compExtra, model)

        return auxTelAlgo.getZer4UpInNm()
Пример #9
0
    def __init__(self, instruFolderPath, algoFolderPath):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instruFolderPath : str
            Path to instrument directory.
        algoFolderPath : str
            Path to algorithm directory.
        """

        # Plain defaults: no optical model chosen and a zero image size
        self.opticalModel = ""
        self.sizeInPix = 0

        # Collaborators, created in the order algorithm -> instrument
        self.algo = Algorithm(algoFolderPath)
        self.inst = Instrument(instruFolderPath)

        # One image holder for each side of focus
        self.ImgIntra = CompensationImageDecorator()
        self.ImgExtra = CompensationImageDecorator()
Пример #10
0
    def testMakeMasks(self):
        """Check the masks of a DonutStamp before and after makeMasks()."""

        stamp = DonutStamp(
            self.testStamps[0],
            lsst.geom.SpherePoint(0.0, 0.0, lsst.geom.degrees),
            lsst.geom.Point2D(2047.5, 2001.5),
            DefocalType.Extra.value,
            "R22_S11",
            "LSSTCam",
        )

        # Set up instrument
        donutWidth = 126
        inst = Instrument(os.path.join(getConfigDir(), "cwfs", "instData"))
        inst.config(CamType.LsstCam, donutWidth)

        # Check that masks are empty at start
        for emptyMask in (stamp.mask_comp, stamp.mask_pupil):
            np.testing.assert_array_equal(np.empty(shape=(0, 0)),
                                          emptyMask.getArray())

        # Check masks after creation
        stamp.makeMasks(inst, "offAxis", 0, 1)

        for madeMask in (stamp.mask_comp, stamp.mask_pupil):
            self.assertEqual(afwImage.MaskX, type(madeMask))

        expectedPlanes = {"BKGRD": 0, "DONUT": 1}
        for madeMask in (stamp.mask_comp, stamp.mask_pupil):
            self.assertDictEqual(expectedPlanes, madeMask.getMaskPlaneDict())

        maskArrays = (stamp.mask_comp.getArray(), stamp.mask_pupil.getArray())

        # Donut should match
        for maskArray in maskArrays:
            self.assertEqual(np.shape(maskArray), (donutWidth, donutWidth))

        # Make sure not just an empty array
        for maskArray in maskArrays:
            self.assertGreater(np.sum(maskArray), 0.0)

        # Donut at center of focal plane should be symmetric: the first 63
        # rows equal the last 63 rows taken in reverse order
        for maskArray in maskArrays:
            np.testing.assert_array_equal(maskArray[:63],
                                          maskArray[-63:][::-1])
Пример #11
0
    def setUp(self):
        """Prepare the instrument, image paths, reference Zk and test image.

        The instrument is configured for CamType.LsstCam with a donut
        dimension of 120 and 1.0 mm announced defocal distance. The reference
        Zk values are loaded from the validation text file into
        ``self.zcCol``.
        """

        # Get the path of module
        testRoot = getModulePath()

        # Instrument
        instDataDir = os.path.join(getConfigDir(), "cwfs", "instData")
        donutDim = 120

        self.inst = Instrument(instDataDir)
        self.inst.config(CamType.LsstCam,
                         donutDim,
                         announcedDefocalDisInMm=1.0)

        # Test images are stored as txt files; the final image format is not
        # known yet
        testImageDir = os.path.join(testRoot, "tests", "testData",
                                    "testImages")
        donutDir = os.path.join(testImageDir, "LSST_NE_SN25")
        self.imgFilePathIntra = os.path.join(donutDir, "z11_0.25_intra.txt")
        self.imgFilePathExtra = os.path.join(donutDir, "z11_0.25_extra.txt")

        # This is the position of donut on the focal plane in degree
        self.fieldXY = (1.185, 1.185)

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Get the true Zk
        trueZkPath = os.path.join(testImageDir, "validation", "simulation",
                                  "LSST_NE_SN25_z11_0.25_exp.txt")
        self.zcCol = np.loadtxt(trueZkPath)

        self.wfsImg = CompensableImage()
Пример #12
0
    def setUp(self):
        """Create an Instrument configured for CamType.LsstCam.

        The donut dimension on the sensor is 120 and the announced defocal
        distance is 1.5 mm.
        """

        self.dimOfDonutOnSensor = 120
        self.instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        self.inst = Instrument(self.instDir)
        self.inst.config(
            CamType.LsstCam,
            self.dimOfDonutOnSensor,
            announcedDefocalDisInMm=1.5,
        )
Пример #13
0
    def __init__(self, algoDir):
        """Build an Algorithm whose solver state starts out empty.

        The algorithm solves the transport of intensity equation to get the
        normal/ annular Zernike polynomials.

        Parameters
        ----------
        algoDir : str
            Algorithm configuration directory.
        """

        # Where the configuration lives, plus a parameter reader created with
        # no arguments
        self.algoDir = algoDir
        self.algoParamFile = ParamReader()

        # Instrument placeholder created from an empty directory string
        self._inst = Instrument("")

        # Calculation messages are shown based on this value; 0 means no
        # message is shown
        self.debugLevel = 0

        # Whether the image has the problem from the over-compensation
        self.caustic = False

        # Current number of outer-loop iteration
        self.currentItr = 0

        # Per-iteration record of the Zk coefficients. The actual total
        # outer-loop iteration time is Num_of_outer_itr + 1.
        self.converge = np.array([])

        # Coefficients of normal/ annular Zernike polynomials after z4 in
        # unit of nm
        self.zer4UpNm = np.array([])

        # Converged wavefront and converged Zk coefficients
        self.wcomp = np.array([])
        self.zcomp = np.array([])

        # Wavefront and Zk coefficients calculated in the previous outer-loop
        # iteration
        self.West = np.array([])
        self.zc = np.array([])

        # Padded mask for use at the offset planes (pMask) and non-padded
        # mask corresponding to aperture (cMask)
        self.pMask = None
        self.cMask = None

        # Masks with the dimension changed for fft to use
        self.pMaskPad = None
        self.cMaskPad = None
Пример #14
0
    def __init__(self, instDir, algoDir):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instDir : str
            Path to instrument directory.
        algoDir : str
            Path to algorithm directory.
        """

        # Plain defaults: no optical model chosen and a zero image size
        self.opticalModel = ""
        self.sizeInPix = 0

        # Collaborators, created in the order instrument -> algorithm
        self.inst = Instrument(instDir)
        self.algo = Algorithm(algoDir)

        # One image holder for each side of focus
        self.imgIntra = CompensableImage()
        self.imgExtra = CompensableImage()
    def setUp(self):
        """Prepare the instrument, the image paths and the reference Zk.

        The instrument is configured with the name "lsst" and 120 sensor
        samples. The reference Zk values are loaded from the validation text
        file into ``self.zcCol``.
        """

        # Get the path of module
        testRoot = getModulePath()

        # Instrument: folder, name and number of sensor samples
        instruFolder = os.path.join(testRoot, "configData", "cwfs",
                                    "instruData")
        self.inst = Instrument(instruFolder)
        self.inst.config("lsst", 120)

        # Test images are stored as txt files; the final image format is not
        # known yet
        testImageDir = os.path.join(testRoot, "tests", "testData",
                                    "testImages")
        donutDir = os.path.join(testImageDir, "LSST_NE_SN25")
        self.imgFilePathIntra = os.path.join(donutDir, "z11_0.25_intra.txt")
        self.imgFilePathExtra = os.path.join(donutDir, "z11_0.25_extra.txt")

        # Position of donut on the focal plane in degree
        # ([1.185, 1.185] or [0, 0])
        self.fieldXY = [1.185, 1.185]

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Get the true Zk
        trueZkPath = os.path.join(testImageDir, "validation",
                                  "LSST_NE_SN25_z11_0.25_exp.txt")
        self.zcCol = np.loadtxt(trueZkPath)
Пример #16
0
    def setUp(self):
        """Prepare the donut images, the instrument and both algorithms.

        ``self.I1``/``self.I2`` hold the intra-/extra-focal LSST_NE_SN25
        images, ``self.inst`` is configured for CamType.LsstCam, and
        ``self.algoExp``/``self.algoFft`` are configured with the "exp" and
        "fft" settings.
        """

        self.modulePath = getModulePath()

        # Optical model used by the tests: "paraxial", "onAxis" or "offAxis"
        self.opticalModel = "offAxis"

        # Position of donut on the focal plane in degree
        # ([1.185, 1.185] or [0, 0])
        donutFieldXY = [1.185, 1.185]

        # Test images are stored as txt files; the final image format is not
        # known yet
        imageDir = os.path.join(
            self.modulePath, "tests", "testData", "testImages", "LSST_NE_SN25"
        )
        intraPath = os.path.join(imageDir, "z11_0.25_intra.txt")
        extraPath = os.path.join(imageDir, "z11_0.25_extra.txt")

        # I1 is the intra-focal image and I2 is the extra-focal image
        self.I1 = CompensableImage()
        self.I2 = CompensableImage()
        self.I1.setImg(donutFieldXY, DefocalType.Intra, imageFile=intraPath)
        self.I2.setImg(donutFieldXY, DefocalType.Extra, imageFile=extraPath)

        # Instrument, sized from the intra-focal image
        cwfsDir = os.path.join(getConfigDir(), "cwfs")
        self.inst = Instrument(os.path.join(cwfsDir, "instData"))
        self.inst.config(
            CamType.LsstCam,
            self.I1.getImgSizeInPix(),
            announcedDefocalDisInMm=1.0,
        )

        # Algorithms: one instance per configuration name
        algoConfigDir = os.path.join(cwfsDir, "algo")

        self.algoExp = Algorithm(algoConfigDir)
        self.algoExp.config("exp", self.inst)

        self.algoFft = Algorithm(algoConfigDir)
        self.algoFft.config("fft", self.inst)
Пример #17
0
    def __init__(self, instDir, algoDir):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instDir : str
            Path to instrument directory.
        algoDir : str
            Path to algorithm directory.
        """

        # Plain defaults: no optical model chosen and a zero image size
        self.opticalModel = ""
        self.sizeInPix = 0

        # Collaborators, created in the order instrument -> algorithm
        self.inst = Instrument(instDir)
        self.algo = Algorithm(algoDir)

        # One image holder for each side of focus
        self.imgIntra = CompensableImage()
        self.imgExtra = CompensableImage()
Пример #18
0
class TestAlgorithm(unittest.TestCase):
    """Test the Algorithm class."""

    def setUp(self):
        """Prepare the donut images, the instrument and both algorithms.

        ``self.I1``/``self.I2`` hold the intra-/extra-focal LSST_NE_SN25
        images, ``self.inst`` is configured for CamType.LsstCam, and
        ``self.algoExp``/``self.algoFft`` are configured with the "exp" and
        "fft" settings.
        """

        # Get the path of module
        self.modulePath = getModulePath()

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # It is noted that the image files read here are txt files
        imageFolderPath = os.path.join(self.modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"

        # Define fieldXY: [1.185, 1.185] or [0, 0]
        # This is the position of donut on the focal plane in degree
        fieldXY = [1.185, 1.185]

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Image files Path
        intra_image_file = os.path.join(imageFolderPath, intra_image_name)
        extra_image_file = os.path.join(imageFolderPath, extra_image_name)

        # There is the difference between intra and extra images
        # I1: intra_focal images, I2: extra_focal Images
        self.I1 = CompensableImage()
        self.I2 = CompensableImage()

        self.I1.setImg(fieldXY, DefocalType.Intra, imageFile=intra_image_file)
        self.I2.setImg(fieldXY, DefocalType.Extra, imageFile=extra_image_file)

        # Set up the instrument
        cwfsConfigDir = os.path.join(getConfigDir(), "cwfs")

        instDir = os.path.join(cwfsConfigDir, "instData")
        self.inst = Instrument(instDir)

        self.inst.config(CamType.LsstCam,
                         self.I1.getImgSizeInPix(),
                         announcedDefocalDisInMm=1.0)

        # Set up the algorithm
        algoDir = os.path.join(cwfsConfigDir, "algo")

        self.algoExp = Algorithm(algoDir)
        self.algoExp.config("exp", self.inst)

        self.algoFft = Algorithm(algoDir)
        self.algoFft.config("fft", self.inst)

    def testGetDebugLevel(self):
        """The debug level is 0 right after configuration."""

        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testSetDebugLevel(self):
        """The debug level follows config() and setDebugLevel()."""

        self.algoExp.config("exp", self.inst, debugLevel=3)
        self.assertEqual(self.algoExp.getDebugLevel(), 3)

        self.algoExp.setDebugLevel(0)
        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testGetZer4UpInNm(self):
        """getZer4UpInNm() returns a numpy array."""

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertIsInstance(zer4UpNm, np.ndarray)

    def testGetPoissonSolverName(self):
        """Each algorithm reports the solver name it was configured with."""

        self.assertEqual(self.algoExp.getPoissonSolverName(), "exp")
        self.assertEqual(self.algoFft.getPoissonSolverName(), "fft")

    def testGetNumOfZernikes(self):
        """Both configurations use 22 Zernike terms."""

        self.assertEqual(self.algoExp.getNumOfZernikes(), 22)
        self.assertEqual(self.algoFft.getNumOfZernikes(), 22)

    def testGetZernikeTerms(self):
        """The Zernike terms are integers and as many as getNumOfZernikes()."""

        zTerms = self.algoExp.getZernikeTerms()
        # assertTrue(type(x), int) could never fail because its second
        # argument is only the failure message, so check the type explicitly.
        # numpy integers are accepted as well as the builtin int.
        self.assertIsInstance(zTerms[0], (int, np.integer))
        self.assertEqual(len(zTerms), self.algoExp.getNumOfZernikes())
        self.assertEqual(zTerms[1], 1)
        self.assertEqual(zTerms[-1], self.algoExp.getNumOfZernikes() - 1)

        zTerms = self.algoFft.getZernikeTerms()
        self.assertIsInstance(zTerms[0], (int, np.integer))
        self.assertEqual(len(zTerms), self.algoExp.getNumOfZernikes())

    def testGetObsOfZernikes(self):
        """The obscuration of the Zernikes equals the instrument's one."""

        self.assertEqual(self.algoExp.getObsOfZernikes(),
                         self.inst.getObscuration())
        self.assertEqual(self.algoFft.getObsOfZernikes(),
                         self.inst.getObscuration())

    def testGetNumOfOuterItr(self):
        """Both configurations use 14 outer-loop iterations."""

        self.assertEqual(self.algoExp.getNumOfOuterItr(), 14)
        self.assertEqual(self.algoFft.getNumOfOuterItr(), 14)

    def testGetNumOfInnerItr(self):
        """The "fft" configuration uses 6 inner-loop iterations."""

        self.assertEqual(self.algoFft.getNumOfInnerItr(), 6)

    def testGetFeedbackGain(self):
        """Both configurations use a feedback gain of 0.6."""

        self.assertEqual(self.algoExp.getFeedbackGain(), 0.6)
        self.assertEqual(self.algoFft.getFeedbackGain(), 0.6)

    def testGetOffAxisPolyOrder(self):
        """Both configurations use an off-axis polynomial order of 10."""

        self.assertEqual(self.algoExp.getOffAxisPolyOrder(), 10)
        self.assertEqual(self.algoFft.getOffAxisPolyOrder(), 10)

    def testGetCompensatorMode(self):
        """Both configurations use the "zer" compensator mode."""

        self.assertEqual(self.algoExp.getCompensatorMode(), "zer")
        self.assertEqual(self.algoFft.getCompensatorMode(), "zer")

    def testGetCompSequence(self):
        """The compensation sequence has one integer per outer iteration."""

        compSequence = self.algoExp.getCompSequence()
        self.assertIsInstance(compSequence, np.ndarray)
        self.assertEqual(compSequence.dtype, int)
        self.assertEqual(len(compSequence), self.algoExp.getNumOfOuterItr())
        self.assertEqual(compSequence[0], 4)
        self.assertEqual(compSequence[-1], 22)

        compSequence = self.algoFft.getCompSequence()
        self.assertEqual(len(compSequence), self.algoFft.getNumOfOuterItr())

    def testGetBoundaryThickness(self):
        """The boundary thickness is 8 for "exp" and 1 for "fft"."""

        self.assertEqual(self.algoExp.getBoundaryThickness(), 8)
        self.assertEqual(self.algoFft.getBoundaryThickness(), 1)

    def testGetFftDimension(self):
        """The "fft" configuration uses an FFT dimension of 128."""

        self.assertEqual(self.algoFft.getFftDimension(), 128)

    def testGetSignalClipSequence(self):
        """The clip sequence has one more entry than the outer iterations."""

        sumclipSequence = self.algoFft.getSignalClipSequence()
        self.assertIsInstance(sumclipSequence, np.ndarray)
        self.assertEqual(len(sumclipSequence),
                         self.algoExp.getNumOfOuterItr() + 1)
        self.assertEqual(sumclipSequence[0], 0.33)
        self.assertEqual(sumclipSequence[-1], 0.51)

    def testGetMaskScalingFactor(self):
        """The mask scaling factor is 1.0939 to 4 places for both."""

        self.assertAlmostEqual(self.algoExp.getMaskScalingFactor(),
                               1.0939,
                               places=4)
        self.assertAlmostEqual(self.algoFft.getMaskScalingFactor(),
                               1.0939,
                               places=4)

    def testGetWavefrontMapEstiInIter0(self):
        """Asking for the wavefront map before any run raises RuntimeError."""

        self.assertRaises(RuntimeError, self.algoExp.getWavefrontMapEsti)

    def testGetWavefrontMapEstiAndResidual(self):
        """Check the size of the estimated and residual wavefront maps."""

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # NaN entries are zeroed so that they do not poison the sums
        wavefrontMapEsti = self.algoExp.getWavefrontMapEsti()
        wavefrontMapEsti[np.isnan(wavefrontMapEsti)] = 0
        self.assertGreater(np.sum(np.abs(wavefrontMapEsti)), 4.8e-4)

        wavefrontMapResidual = self.algoExp.getWavefrontMapResidual()
        wavefrontMapResidual[np.isnan(wavefrontMapResidual)] = 0
        self.assertLess(np.sum(np.abs(wavefrontMapResidual)), 2.5e-6)

    def testItr0(self):
        """itr0() gives the reference coefficients after rounding."""

        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertEqual(
            np.sum(np.abs(np.rint(zer4UpNm) - self._getAnsItr0())), 0)

    def _getAnsItr0(self):
        """Return the rounded coefficients expected after iteration 0.

        Returns
        -------
        list[int]
            Reference values compared against the rounded result of
            getZer4UpInNm().
        """

        return [
            31,
            -69,
            -21,
            84,
            44,
            -53,
            48,
            -146,
            6,
            10,
            13,
            -5,
            1,
            -12,
            -8,
            7,
            0,
            -6,
            11,
        ]

    def testNextItrWithOneIter(self):
        """nextItr() with nItr=1 matches the iteration-0 reference."""

        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=1)

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertEqual(
            np.sum(np.abs(np.rint(zer4UpNm) - self._getAnsItr0())), 0)

    def testNextItrWithTwoIter(self):
        """nextItr() with nItr=2 matches its own rounded reference."""

        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=2)
        zer4UpNm = self.algoExp.getZer4UpInNm()

        ansRint = [
            40,
            -80,
            -18,
            92,
            44.0,
            -52,
            54,
            -146,
            5,
            10,
            15,
            -3,
            -0,
            -12,
            -8,
            7,
            1,
            -3,
            12,
        ]
        self.assertEqual(np.sum(np.abs(np.rint(zer4UpNm) - ansRint)), 0)

    def testIter0AndNextIterToCheckReset(self):
        """itr0() gives the same result before and after a nextItr() call."""

        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp1 = self.algoExp.getZer4UpInNm()

        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=2)

        # itr0() should reset the images and ignore the effect from nextItr()
        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp2 = self.algoExp.getZer4UpInNm()

        difference = np.sum(np.abs(tmp1 - tmp2))
        self.assertEqual(difference, 0)

    def testRunItOfExp(self):
        """A full "exp" run gives -192 for the coefficient at index 7."""

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Check the value
        zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)

    def testResetAfterFullCalc(self):
        """A second full run after reset() reproduces the first result."""

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Reset and check the calculation again: both images are set back to
        # their initial data before the algorithm itself is reset
        fieldXY = [self.I1.fieldX, self.I1.fieldY]
        self.I1.setImg(fieldXY,
                       self.I1.getDefocalType(),
                       image=self.I1.getImgInit())
        self.I2.setImg(fieldXY,
                       self.I2.getDefocalType(),
                       image=self.I2.getImgInit())
        self.algoExp.reset()

        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)

    def testRunItOfFft(self):
        """A full "fft" run gives -192 for the coefficient at index 7."""

        self.algoFft.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        zk = self.algoFft.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)
Пример #19
0
class TestInstrument(unittest.TestCase):
    """Unit tests of the Instrument class.

    Unless a test builds its own instrument, the instrument under test is the
    one configured in setUp(): CamType.LsstCam, a donut dimension of 120
    pixels and an announced defocal distance of 1.5 mm.
    """

    def setUp(self):
        """Create and configure the instrument shared by the tests."""

        self.dimOfDonutOnSensor = 120
        self.instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        self.inst = Instrument(self.instDir)
        self.inst.config(
            CamType.LsstCam,
            self.dimOfDonutOnSensor,
            announcedDefocalDisInMm=1.5,
        )

    def _assertHasDonutShape(self, grid):
        """Assert that grid is square with the configured donut dimension."""

        expectedShape = (self.dimOfDonutOnSensor, self.dimOfDonutOnSensor)
        self.assertEqual(grid.shape, expectedShape)

    def testConfigWithUnsupportedCamType(self):
        """config() raises ValueError for an unknown camera type."""

        with self.assertRaises(ValueError):
            self.inst.config("NoThisCamType", 120)

    def testGetInstFileDir(self):
        """The instrument file directory is the "lsst" sub-directory."""

        expectedDir = os.path.join(self.instDir, "lsst")
        self.assertEqual(self.inst.getInstFileDir(), expectedDir)

    def testGetAnnDefocalDisInMm(self):
        """The announced defocal distance is the value given to config()."""

        self.assertEqual(self.inst.getAnnDefocalDisInMm(), 1.5)

    def testSetAnnDefocalDisInMm(self):
        """A newly set announced defocal distance is returned by the getter."""

        newDefocalDisInMm = 2.0
        self.inst.setAnnDefocalDisInMm(newDefocalDisInMm)

        self.assertEqual(self.inst.getAnnDefocalDisInMm(), newDefocalDisInMm)

    def testGetInstFilePath(self):
        """The instrument file exists and is named "instParam.yaml"."""

        filePath = self.inst.getInstFilePath()

        self.assertTrue(os.path.exists(filePath))
        self.assertEqual(os.path.basename(filePath), "instParam.yaml")

    def testGetMaskOffAxisCorr(self):
        """Check the shape and two entries of the off-axis mask correction."""

        corr = self.inst.getMaskOffAxisCorr()

        self.assertEqual(corr.shape, (9, 5))
        self.assertEqual(corr[0, 0], 1.07)
        self.assertEqual(corr[2, 3], -0.090100858)

    def testGetDimOfDonutOnSensor(self):
        """The donut dimension is the value given to config()."""

        self.assertEqual(
            self.inst.getDimOfDonutOnSensor(), self.dimOfDonutOnSensor
        )

    def testGetObscuration(self):
        """Check the obscuration of the configured instrument."""

        self.assertEqual(self.inst.getObscuration(), 0.61)

    def testGetFocalLength(self):
        """Check the focal length of the configured instrument."""

        self.assertEqual(self.inst.getFocalLength(), 10.312)

    def testGetApertureDiameter(self):
        """Check the aperture diameter of the configured instrument."""

        self.assertEqual(self.inst.getApertureDiameter(), 8.36)

    def testGetDefocalDisOffset(self):
        """Check the defocal distance offset after scaling by 1e3."""

        offsetInM = self.inst.getDefocalDisOffset()

        # The answer is 1.5 mm
        self.assertEqual(offsetInM * 1e3, 1.5)

    def testGetCamPixelSize(self):
        """Check the camera pixel size after scaling by 1e6."""

        pixelSizeInM = self.inst.getCamPixelSize()

        # The answer is 10 um
        self.assertEqual(pixelSizeInM * 1e6, 10)

    def testGetMarginalFocalLength(self):
        """Check the marginal focal length to 4 decimal places."""

        self.assertAlmostEqual(
            self.inst.getMarginalFocalLength(), 9.4268, places=4
        )

    def testGetSensorFactor(self):
        """Check the sensor factor to 5 decimal places."""

        self.assertAlmostEqual(
            self.inst.getSensorFactor(), 0.98679, places=5
        )

    def testGetSensorCoor(self):
        """Check the shape and the first entries of the sensor coordinates."""

        xSensor, ySensor = self.inst.getSensorCoor()

        # x changes along the second index
        self._assertHasDonutShape(xSensor)
        self.assertAlmostEqual(xSensor[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(xSensor[0, 1], -0.96212, places=5)

        # y changes along the first index
        self._assertHasDonutShape(ySensor)
        self.assertAlmostEqual(ySensor[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(ySensor[1, 0], -0.96212, places=5)

    def testGetSensorCoorAnnular(self):
        """Check that the annular coordinates are NaN at [0, 0] and [60, 60]."""

        xoSensor, yoSensor = self.inst.getSensorCoorAnnular()

        for coor in (xoSensor, yoSensor):
            self._assertHasDonutShape(coor)
            self.assertTrue(np.isnan(coor[0, 0]))
            self.assertTrue(np.isnan(coor[60, 60]))

    def testCalcSizeOfDonutExpected(self):
        """Check the expected donut size of the configured instrument."""

        self.assertAlmostEqual(
            self.inst.calcSizeOfDonutExpected(), 121.60589604, places=7
        )

    def testDataAuxTel(self):
        """Check the parameters of an instrument configured as CamType.AuxTel."""

        auxTel = Instrument(self.instDir)
        auxTel.config(CamType.AuxTel, 160, announcedDefocalDisInMm=0.8)

        self.assertEqual(auxTel.getObscuration(), 0.3525)
        self.assertEqual(auxTel.getFocalLength(), 21.6)
        self.assertEqual(auxTel.getApertureDiameter(), 1.2)
        self.assertEqual(auxTel.getDefocalDisOffset(), 0.041 * 0.8)
        self.assertEqual(auxTel.getCamPixelSize(), 10.0e-6)
        self.assertAlmostEqual(
            auxTel.calcSizeOfDonutExpected(), 182.2222222, places=7
        )

    def testDataAuxTelZWO(self):
        """Check the parameters of an instrument configured as
        CamType.AuxTelZWO."""

        auxTelZwo = Instrument(self.instDir)
        auxTelZwo.config(CamType.AuxTelZWO, 160, announcedDefocalDisInMm=0.5)

        self.assertEqual(auxTelZwo.getObscuration(), 0.3525)
        self.assertEqual(auxTelZwo.getFocalLength(), 21.6)
        self.assertEqual(auxTelZwo.getApertureDiameter(), 1.2)
        self.assertEqual(auxTelZwo.getDefocalDisOffset(), 0.0205)
        self.assertEqual(auxTelZwo.getCamPixelSize(), 15.2e-6)
        self.assertAlmostEqual(
            auxTelZwo.calcSizeOfDonutExpected(), 74.92690058, places=7
        )
Пример #20
0
    def makeTemplate(
        self,
        sensorName,
        defocalType,
        imageSize,
        camType=CamType.LsstCam,
        opticalModel="offAxis",
        pixelScale=0.2,
    ):
        """Make the donut template image.

        The template is the non-padded mask of an all-zero image placed at
        the field position of the sensor center.

        Parameters
        ----------
        sensorName : str
            The camera detector for which we want to make a template. Should
            be in "Rxx_Sxx" format.
        defocalType : enum 'DefocalType'
            The defocal state of the sensor.
        imageSize : int
            Size of template in pixels. The template will be a square.
        camType : enum 'CamType', optional
            Camera type. (Default is CamType.LsstCam)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
            (The default is "offAxis")
        pixelScale : float, optional
            The pixels to arcseconds conversion factor. (The default is 0.2)

        Returns
        -------
        numpy.ndarray [int]
            The donut template as a binary image.
        """

        configDir = getConfigDir()
        sensorLayout = readPhoSimSettingData(
            configDir, "focalplanelayout.txt", "fieldCenter"
        )[sensorName]

        # The first two layout entries are the sensor center, the third is
        # the pixel size and the fourth the number of pixels along x
        pixelSizeInUm = float(sensorLayout[2])
        sizeXinPixel = int(sensorLayout[3])
        sensorXMicron, sensorYMicron = np.array(sensorLayout[:2], dtype=float)

        # Correction for wavefront sensors: the center is moved by half of
        # the sensor x size along one axis, in the direction given below
        halfSizeInUm = sizeXinPixel / 2 * pixelSizeInUm
        centerShift = {
            "R44_S00_C0": ("x", 1),
            "R00_S22_C1": ("x", 1),
            "R44_S00_C1": ("x", -1),
            "R00_S22_C0": ("x", -1),
            "R04_S20_C1": ("y", -1),
            "R40_S02_C0": ("y", -1),
            "R04_S20_C0": ("y", 1),
            "R40_S02_C1": ("y", 1),
        }
        if sensorName in centerShift:
            axis, direction = centerShift[sensorName]
            if axis == "x":
                sensorXMicron = sensorXMicron + direction * halfSizeInUm
            else:
                sensorYMicron = sensorYMicron + direction * halfSizeInUm

        # Load Instrument parameters
        inst = Instrument(os.path.join(configDir, "cwfs", "instData"))
        inst.config(camType, imageSize)

        # Create image for mask
        img = CompensableImage()

        # Convert the sensor center from microns to pixels
        sensorXPixel = float(sensorXMicron) / pixelSizeInUm
        sensorYPixel = float(sensorYMicron) / pixelSizeInUm

        # Multiply by pixelScale then divide by 3600 for arcsec -> deg
        # conversion
        fieldXY = [
            sensorXPixel * pixelScale / 3600,
            sensorYPixel * pixelScale / 3600,
        ]

        # Put an empty image at the sensor center and build its mask with no
        # boundary thickness and a unit mask scaling factor
        img.setImg(
            fieldXY, defocalType, image=np.zeros((imageSize, imageSize))
        )
        img.makeMask(inst, opticalModel, 0, 1)

        return img.getNonPaddedMask()
Пример #21
0
    def calcWfErr(
        self,
        centroidFindType,
        fieldXY,
        camType,
        algoName,
        announcedDefocalDisInMm,
        opticalModel,
        imageIntra=None,
        imageExtra=None,
        imageFileIntra=None,
        imageFileExtra=None,
    ):
        """Calculate the wavefront error.

        The intra- and extra-focal inputs are wrapped as compensable images,
        an instrument and an algorithm are configured from the package
        configuration directory, and the algorithm is run with a tolerance of
        1e-3.

        Parameters
        ----------
        centroidFindType : enum 'CentroidFindType'
            Algorithm to find the centroid of donut.
        fieldXY : tuple or list
            Position of donut on the focal plane in degree (field x, field y).
        camType : enum 'CamType'
            Camera type.
        algoName : str
            Algorithm configuration file to solve the Poisson's equation in the
            transport of intensity equation (TIE). It can be "fft" or "exp"
            here.
        announcedDefocalDisInMm : float
            Announced defocal distance in mm. It is noted that the defocal
            distance offset used in calculation might be different from this
            value.
        opticalModel : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        imageIntra : numpy.ndarray, optional
            Array of intra-focal image. (the default is None.)
        imageExtra : numpy.ndarray, optional
            Array of extra-focal image. (the default is None.)
        imageFileIntra : str, optional
            Path of intra-focal image file. (the default is None.)
        imageFileExtra : str, optional
            Path of extra-focal image file. (the default is None.)

        Returns
        -------
        numpy.ndarray
            Zernike polynomials of z4-zn in nm.
        """

        # One (defocal type, image array, image file) triple per image
        defocalInputs = (
            (DefocalType.Intra, imageIntra, imageFileIntra),
            (DefocalType.Extra, imageExtra, imageFileExtra),
        )

        # Set the defocal images. Both share the same field position.
        imgIntra = CompensableImage(centroidFindType=centroidFindType)
        imgExtra = CompensableImage(centroidFindType=centroidFindType)
        for img, (defocalType, image, imageFile) in zip(
            (imgIntra, imgExtra), defocalInputs
        ):
            img.setImg(fieldXY, defocalType, image=image, imageFile=imageFile)

        # Set the instrument. The donut dimension is taken from the
        # intra-focal image.
        inst = Instrument(os.path.join(getConfigDir(), "cwfs", "instData"))
        inst.config(
            camType,
            imgIntra.getImgSizeInPix(),
            announcedDefocalDisInMm=announcedDefocalDisInMm,
        )

        # Define the algorithm to be used.
        algo = Algorithm(os.path.join(getConfigDir(), "cwfs", "algo"))
        algo.config(algoName, inst, debugLevel=0)

        # Run it
        algo.runIt(imgIntra, imgExtra, opticalModel, tol=1e-3)

        # Return the Zernikes Zn (n>=4)
        return algo.getZer4UpInNm()
Пример #22
0
class Algorithm(object):

    def __init__(self, algoDir):
        """Initialize the Algorithm class.

        Algorithm used to solve the transport of intensity equation to get
        normal/ annular Zernike polynomials. All result arrays start empty
        and all masks start as None; config() gives them their sizes.

        Parameters
        ----------
        algoDir : str
            Algorithm configuration directory.
        """

        # Configuration directory and the reader of the parameter file
        self.algoDir = algoDir
        self.algoParamFile = ParamReader()

        # Placeholder instrument until config() provides the real one
        self._inst = Instrument("")

        # Verbosity of the calculation messages; 0 shows no message
        self.debugLevel = 0

        # True once an image has a problem from the over-compensation
        self.caustic = False

        # Current number of outer-loop iteration
        self.currentItr = 0

        # Zk coefficients recorded in each outer-loop iteration. The actual
        # total outer-loop iteration time is Num_of_outer_itr + 1.
        self.converge = np.array([])

        # Coefficients of normal/ annular Zernike polynomials after z4 in
        # unit of nm
        self.zer4UpNm = np.array([])

        # Converged wavefront, and the wavefront calculated in the previous
        # outer-loop iteration
        self.wcomp = np.array([])
        self.West = np.array([])

        # Converged Zk coefficients, and the Zk coefficients calculated in
        # the previous outer-loop iteration
        self.zcomp = np.array([])
        self.zc = np.array([])

        # Padded mask for use at the offset planes, and non-padded mask
        # corresponding to aperture
        self.pMask, self.cMask = None, None

        # The same masks with the dimension changed for fft to use
        self.pMaskPad, self.cMaskPad = None, None

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings."""

        self.caustic = False
        self.converge = np.zeros(self.converge.shape)
        self.currentItr = 0
        self.zer4UpNm = np.zeros(self.zer4UpNm.shape)

        self.wcomp = np.zeros(self.wcomp.shape)
        self.West = np.zeros(self.West.shape)

        self.zcomp = np.zeros(self.zcomp.shape)
        self.zc = np.zeros(self.zc.shape)

        self.pMask = None
        self.cMask = None

        self.pMaskPad = None
        self.cMaskPad = None

    def config(self, algoName, inst, debugLevel=0):
        """Configure the algorithm to solve TIE.

        Parameters
        ----------
        algoName : str
            Algorithm configuration file to solve the Poisson's equation in the
            transport of intensity equation (TIE). It can be "fft" or "exp"
            here.
        inst : Instrument
            Instrument to use.
        debugLevel : int, optional
            Show the information under the running. If the value is higher, the
            information shows more. It can be 0, 1, 2, or 3. (the default is
            0.)
        """

        algoParamFilePath = os.path.join(self.algoDir, "%s.yaml" % algoName)
        self.algoParamFile.setFilePath(algoParamFilePath)

        self._inst = inst
        self.debugLevel = debugLevel

        self.caustic = False

        numTerms = self.getNumOfZernikes()
        outerItr = self.getNumOfOuterItr()
        self.converge = np.zeros((numTerms, outerItr + 1))

        self.currentItr = 0

        self.zer4UpNm = np.zeros(numTerms - 3)

        # Wavefront related parameters
        dimOfDonut = self._inst.getDimOfDonutOnSensor()
        self.wcomp = np.zeros((dimOfDonut, dimOfDonut))
        self.West = self.wcomp.copy()

        # Used in model basis ("zer").
        self.zcomp = np.zeros(numTerms)
        self.zc = self.zcomp.copy()

        # Mask related variables
        self.pMask = None
        self.cMask = None
        self.pMaskPad = None
        self.cMaskPad = None

    def setDebugLevel(self, debugLevel):
        """Set the debug level.

        If the value is higher, the information shows more. It can be 0, 1, 2,
        or 3.

        Parameters
        ----------
        debugLevel : int
            Show the information under the running.
        """

        self.debugLevel = int(debugLevel)

    def getDebugLevel(self):
        """Get the debug level.

        If the value is higher, the information shows more. It can be 0, 1, 2,
        or 3.

        Returns
        -------
        int
            Debug level.
        """

        return self.debugLevel

    def getZer4UpInNm(self):
        """Get the coefficients of Zernike polynomials of z4-zn in nm.

        Returns
        -------
        numpy.ndarray
            Zernike polynomials of z4-zn in nm.
        """

        return self.zer4UpNm

    def getPoissonSolverName(self):
        """Get the method name to solve the Poisson equation.

        Returns
        -------
        str
            Method name to solve the Poisson equation.
        """

        return self.algoParamFile.getSetting("poissonSolver")

    def getNumOfZernikes(self):
        """Get the maximum number of Zernike polynomials supported.

        Returns
        -------
        int
            Maximum number of Zernike polynomials supported.
        """

        return int(self.algoParamFile.getSetting("numOfZernikes"))

    def getZernikeTerms(self):
        """Get the Zernike terms in using.

        Returns
        -------
        numpy.ndarray
            Zernkie terms in using.
        """

        numTerms = self.getNumOfZernikes()
        zTerms = np.arange(numTerms) + 1

        return zTerms

    def getObsOfZernikes(self):
        """Get the obscuration of annular Zernike polynomials.

        Returns
        -------
        float
            Obscuration of annular Zernike polynomials
        """

        zobsR = self.algoParamFile.getSetting("obsOfZernikes")
        if (zobsR == 1):
            zobsR = self._inst.getObscuration()

        return float(zobsR)

    def getNumOfOuterItr(self):
        """Get the number of outer loop iteration.

        Returns
        -------
        int
            Number of outer loop iteration.
        """

        return int(self.algoParamFile.getSetting("numOfOuterItr"))

    def getNumOfInnerItr(self):
        """Get the number of inner loop iteration.

        This is for the fast Fourier transform (FFT) solver only.

        Returns
        -------
        int
            Number of inner loop iteration.
        """

        return int(self.algoParamFile.getSetting("numOfInnerItr"))

    def getFeedbackGain(self):
        """Get the gain value used in the outer loop iteration.

        Returns
        -------
        float
            Gain value used in the outer loop iteration.
        """

        return self.algoParamFile.getSetting("feedbackGain")

    def getOffAxisPolyOrder(self):
        """Get the number of polynomial order supported in off-axis correction.

        Returns
        -------
        int
            Number of polynomial order supported in off-axis correction.
        """

        return int(self.algoParamFile.getSetting("offAxisPolyOrder"))

    def getCompensatorMode(self):
        """Get the method name to compensate the wavefront by wavefront error.

        Returns
        -------
        str
            Method name to compensate the wavefront by wavefront error.
        """

        return self.algoParamFile.getSetting("compensatorMode")

    def getCompSequence(self):
        """Get the compensated sequence of Zernike order for each iteration.

        Returns
        -------
        numpy.ndarray[int]
            Compensated sequence of Zernike order for each iteration.
        """

        compSequenceFromFile = self.algoParamFile.getSetting("compSequence")
        compSequence = np.array(compSequenceFromFile, dtype=int)

        # If outerItr is large, and compSequence is too small,
        # the rest in compSequence will be filled.
        # This is used in the "zer" method.
        outerItr = self.getNumOfOuterItr()
        compSequence = self._extend1dArray(compSequence, outerItr)
        compSequence = compSequence.astype(int)

        return compSequence

    def _extend1dArray(self, origArray, targetLength):
        """Extend the 1D original array to the taget length.

        The extended value will be the final element of original array. Nothing
        will be done if the input array is not 1D or its length is less than
        the target.

        Parameters
        ----------
        origArray : numpy.ndarray
            Original array with 1 dimension.
        targetLength : int
            Target length of new extended array.

        Returns
        -------
        numpy.ndarray
            Extended 1D array.
        """

        if (len(origArray) < targetLength) and (origArray.ndim == 1):
            leftOver = np.ones(targetLength - len(origArray))
            extendArray = np.append(origArray, origArray[-1] * leftOver)
        else:
            extendArray = origArray

        return extendArray

    def getBoundaryThickness(self):
        """Get the boundary thickness that the computation mask extends beyond
        the pupil mask.

        It is noted that in Fast Fourier transform (FFT) algorithm, it is also
        the width of Neuman boundary where the derivative of the wavefront is
        set to zero

        Returns
        -------
        int
            Boundary thickness.
        """

        return int(self.algoParamFile.getSetting("boundaryThickness"))

    def getFftDimension(self):
        """Get the FFT pad dimension in pixel.

        This is for the fast Fourier transform (FFT) solver only.

        Returns
        -------
        int
            FFT pad dimention.
        """

        fftDim = int(self.algoParamFile.getSetting("fftDimension"))

        # Make sure the dimension is the order of multiple of 2
        if (fftDim == 999):
            dimToFit = self._inst.getDimOfDonutOnSensor()
        else:
            dimToFit = fftDim

        padDim = int(2**np.ceil(np.log2(dimToFit)))

        return padDim

    def getSignalClipSequence(self):
        """Get the signal clip sequence.

        The number of values should be the number of compensation plus 1.
        This is for the fast Fourier transform (FFT) solver only.

        Returns
        -------
        numpy.ndarray
            Signal clip sequence.
        """

        sumclipSequenceFromFile = self.algoParamFile.getSetting("signalClipSequence")
        sumclipSequence = np.array(sumclipSequenceFromFile)

        # If outerItr is large, and sumclipSequence is too small, the rest in
        # sumclipSequence will be filled.
        # This is used in the "zer" method.
        targetLength = self.getNumOfOuterItr() + 1
        sumclipSequence = self._extend1dArray(sumclipSequence, targetLength)

        return sumclipSequence

    def getMaskScalingFactor(self):
        """Get the mask scaling factor for fast beam.

        Returns
        -------
        float
            Mask scaling factor for fast beam.
        """

        # m = R'*f/(l*R), R': radius of the no-aberration image
        focalLength = self._inst.getFocalLength()
        marginalFL = self._inst.getMarginalFocalLength()
        maskScalingFactor = focalLength / marginalFL

        return maskScalingFactor

    def itr0(self, I1, I2, model):
        """Calculate the wavefront and coefficients of normal/ annular Zernike
        polynomials in the first iteration time.

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        """

        # Reset the iteration time of outer loop and decide to reset the
        # defocal images or not
        self._reset(I1, I2)

        # Solve the transport of intensity equation (TIE)
        self._singleItr(I1, I2, model)

    def runIt(self, I1, I2, model, tol=1e-3):
        """Calculate the wavefront error by solving the transport of intensity
        equation (TIE).

        The inner (for fft algorithm) and outer loops are used. The inner loop
        is to solve the Poisson's equation. The outer loop is to compensate the
        intra- and extra-focal images to mitigate the calculation of wavefront
        (e.g. S = -1/(delta Z) * (I1 - I2)/ (I1 + I2)).

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        tol : float, optional
            Tolerance of difference of coefficients of Zk polynomials compared
            with the previours iteration. (the default is 1e-3.)
        """

        # To have the iteration time initiated from global variable is to
        # distinguish the manually and automatically iteration processes.
        itr = self.currentItr
        while (itr <= self.getNumOfOuterItr()):
            stopItr = self._singleItr(I1, I2, model, tol)

            # Stop the iteration of outer loop if converged
            if (stopItr):
                break

            itr += 1

    def nextItr(self, I1, I2, model, nItr=1):
        """Run the outer loop iteration with the specific time defined in nItr.

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        nItr : int, optional
            Outer loop iteration time. (the default is 1.)
        """

        #  Do the iteration
        ii = 0
        while (ii < nItr):
            self._singleItr(I1, I2, model)
            ii += 1

    def _singleItr(self, I1, I2, model, tol=1e-3):
        """Run the outer-loop with single iteration to solve the transport of
        intensity equation (TIE).

        This is to compensate the approximation of wavefront:
        S = -1/(delta Z) * (I1 - I2)/ (I1 + I2)).

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        tol : float, optional
            Tolerance of difference of coefficients of Zk polynomials compared
            with the previours iteration. (the default is 1e-3.)

        Returns
        -------
        bool
            Status of iteration.
        """

        # Use the zonal mode ("zer")
        compMode = self.getCompensatorMode()

        # Define the gain of feedbackGain
        feedbackGain = self.getFeedbackGain()

        # Set the pre-condition
        if (self.currentItr == 0):

            # Check this is the first time of running iteration or not
            if (I1.getImgInit() is None or I2.getImgInit() is None):

                # Check the image dimension
                if (I1.getImg().shape != I2.getImg().shape):
                    print("Error: The intra and extra image stamps need to be of same size.")
                    sys.exit()

                # Calculate the pupil mask (binary matrix) and related
                # parameters
                boundaryT = self.getBoundaryThickness()
                I1.makeMask(self._inst, model, boundaryT, 1)
                I2.makeMask(self._inst, model, boundaryT, 1)
                self._makeMasterMask(I1, I2, self.getPoissonSolverName())

                # Load the offAxis correction coefficients
                if (model == "offAxis"):
                    offAxisPolyOrder = self.getOffAxisPolyOrder()
                    I1.setOffAxisCorr(self._inst, offAxisPolyOrder)
                    I2.setOffAxisCorr(self._inst, offAxisPolyOrder)

                # Cocenter the images to the center referenced to fieldX and
                # fieldY. Need to check the availability of this.
                I1.imageCoCenter(self._inst, debugLevel=self.debugLevel)
                I2.imageCoCenter(self._inst, debugLevel=self.debugLevel)

                # Update the self-initial image
                I1.updateImgInit()
                I2.updateImgInit()

            # Initialize the variables used in the iteration.
            self.zcomp = np.zeros(self.getNumOfZernikes())
            self.zc = self.zcomp.copy()

            dimOfDonut = self._inst.getDimOfDonutOnSensor()
            self.wcomp = np.zeros((dimOfDonut, dimOfDonut))
            self.West = self.wcomp.copy()

            self.caustic = False

        # Rename this index (currentItr) for the simplification
        jj = self.currentItr

        # Solve the transport of intensity equation (TIE)
        if (not self.caustic):

            # Reset the images before the compensation
            I1.updateImage(I1.getImgInit().copy())
            I2.updateImage(I2.getImgInit().copy())

            if (compMode == "zer"):

                # Zk coefficient from the previous iteration
                ztmp = self.zc

                # Do the feedback of Zk from the lower terms first based on the
                # sequence defined in compSequence
                if (jj != 0):
                    compSequence = self.getCompSequence()
                    ztmp[int(compSequence[jj - 1]):] = 0

                # Add partial feedback of residual estimated wavefront in Zk
                self.zcomp = self.zcomp + ztmp*feedbackGain

                # Remove the image distortion if the optical model is not
                # "paraxial"
                # Only the optical model of "onAxis" or "offAxis" is considered
                # here
                I1.compensate(self._inst, self, self.zcomp, model)
                I2.compensate(self._inst, self, self.zcomp, model)

            # Check the image condition. If there is the problem, done with
            # this _singleItr().
            if (I1.isCaustic() is True) or (I2.isCaustic() is True):
                self.converge[:, jj] = self.converge[:, jj - 1]
                self.caustic = True
                return

            # Correct the defocal images if I1 and I2 are belong to different
            # sources, which is determined by the (fieldX, field Y)
            I1, I2 = self._applyI1I2pMask(I1, I2)

            # Solve the Poisson's equation
            self.zc, self.West = self._solvePoissonEq(I1, I2, jj)

            # Record/ calculate the Zk coefficient and wavefront
            if (compMode == "zer"):
                self.converge[:, jj] = self.zcomp + self.zc

                xoSensor, yoSensor = self._inst.getSensorCoorAnnular()
                self.wcomp = self.West + ZernikeAnnularEval(
                    np.concatenate(([0, 0, 0], self.zcomp[3:])), xoSensor,
                    yoSensor, self.getObsOfZernikes())

        else:
            # Once we run into caustic, stop here, results may be close to real
            # aberration.
            # Continuation may lead to disatrous results.
            self.converge[:, jj] = self.converge[:, jj - 1]

        # Record the coefficients of normal/ annular Zernike polynomials after
        # z4 in unit of nm
        self.zer4UpNm = self.converge[3:, jj]*1e9

        # Status of iteration
        stopItr = False

        # Calculate the difference
        if (jj > 0):
            diffZk = np.sum(np.abs(self.converge[:, jj]-self.converge[:, jj-1]))*1e9

            # Check the Status of iteration
            if (diffZk < tol):
                stopItr = True

        # Update the current iteration time
        self.currentItr += 1

        # Show the Zk coefficients in interger in each iteration
        if (self.debugLevel >= 2):
            tmp = self.zer4UpNm
            print("itr = %d, z4-z%d" % (jj, self.getNumOfZernikes()))
            print(np.rint(tmp))

        return stopItr

    def _solvePoissonEq(self, I1, I2, iOutItr=0):
        """Solve the Poisson's equation by Fourier transform (differential) or
        serial expansion (integration).

        The branch taken depends on the value returned by
        getPoissonSolverName(): "fft" runs an iterative FFT-based solution
        that re-imposes the measured signal inside the pupil mask on every
        inner iteration; "exp" builds and solves a least-squares system for
        the Zernike coefficients directly.

        There is no convergence for fft actually. Need to add the difference
        comparison and X-alpha method. Need to discuss further for this.

        NOTE(review): if the solver name is neither "fft" nor "exp", or the
        "fft" inner loop runs zero times, `zc`/`West` are never assigned and
        this method fails with UnboundLocalError.

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        iOutItr : int, optional
            ith number of outer loop iteration which is important in "fft"
            algorithm. It selects the clip level from the signal clip
            sequence. (the default is 0.)

        Returns
        -------
        numpy.ndarray
            Coefficients of normal/ annular Zernike polynomials. In "fft"
            mode this is all zeros unless the compensator mode is "zer".
        numpy.ndarray
            Estimated wavefront.
        """

        # Calculate the aperture pixel size
        apertureDiameter = self._inst.getApertureDiameter()
        sensorFactor = self._inst.getSensorFactor()
        dimOfDonut = self._inst.getDimOfDonutOnSensor()
        aperturePixelSize = apertureDiameter*sensorFactor/dimOfDonut

        # Calculate the differential Omega (area of one aperture pixel)
        dOmega = aperturePixelSize**2

        # Solve the Poisson's equation based on the type of algorithm
        numTerms = self.getNumOfZernikes()
        zobsR = self.getObsOfZernikes()
        PoissonSolver = self.getPoissonSolverName()
        if (PoissonSolver == "fft"):

            # Use the differential method by fft to solve the Poisson's
            # equation

            # Parameter to determine the threshold of calculating I0.
            # One entry of the sequence is used per outer-loop iteration.
            sumclipSequence = self.getSignalClipSequence()
            cliplevel = sumclipSequence[iOutItr]

            # Generate the v, u-coordinates on pupil plane
            # (spatial frequencies spanning +/- 0.5/aperturePixelSize)
            padDim = self.getFftDimension()
            v, u = np.mgrid[
                -0.5/aperturePixelSize: 0.5/aperturePixelSize: 1./padDim/aperturePixelSize,
                -0.5/aperturePixelSize: 0.5/aperturePixelSize: 1./padDim/aperturePixelSize]

            # Show the threshold and pupil coordinate information
            if (self.debugLevel >= 3):
                print("iOuter=%d, cliplevel=%4.2f" % (iOutItr, cliplevel))
                print(v.shape)

            # Calculate the const of fft:
            # FT{Delta W} = -4*pi^2*(u^2+v^2) * FT{W}
            u2v2 = -4 * (np.pi**2) * (u*u + v*v)

            # Set origin to Inf to result in 0 at origin after filtering
            # (this also avoids a division by zero at the zero frequency)
            ctrIdx = int(np.floor(padDim/2.0))
            u2v2[ctrIdx, ctrIdx] = np.inf

            # Calculate the wavefront signal
            Sini = self._createSignal(I1, I2, cliplevel)

            # Find the just-outside and just-inside indices of a ring in pixels
            # This is for the use in setting dWdn = 0
            boundaryT = self.getBoundaryThickness()

            struct = generate_binary_structure(2, 1)
            struct = iterate_structure(struct, boundaryT)

            # Ring just outside the mask: dilated mask XOR mask
            ApringOut = np.logical_xor(binary_dilation(self.pMask, structure=struct),
                                       self.pMask).astype(int)
            # Ring just inside the mask: eroded mask XOR mask
            ApringIn = np.logical_xor(binary_erosion(self.pMask, structure=struct),
                                      self.pMask).astype(int)

            # NOTE(review): np.nonzero returns (row indices, column indices),
            # so "bordery" holds rows and "borderx" holds columns, yet the
            # arrays below are indexed as [borderx, bordery]. This is only
            # equivalent if the ring is symmetric under transpose - confirm.
            bordery, borderx = np.nonzero(ApringOut)

            # Put the signal in boundary (since there's no existing Sestimate,
            # S just equals self.S as the initial condition of SCF
            S = Sini.copy()
            for jj in range(self.getNumOfInnerItr()):

                # Calculate FT{S}
                SFFT = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(S)))

                # Calculate W by W=IFT{ FT{S}/(-4*pi^2*(u^2+v^2)) }
                # NOTE(review): irfft2 is given the full (not half) spectrum;
                # confirm this is intended rather than ifft2.
                W = np.fft.fftshift(np.fft.irfft2(np.fft.fftshift(SFFT/u2v2), s=S.shape))

                # Estimate the wavefront (includes zeroing offset & masking to
                # the aperture size)

                # Take the estimated wavefront
                West = extractArray(W, dimOfDonut)

                # Calculate the offset (mean inside the padded mask) and
                # remove it, then zero everything outside the mask
                offset = West[self.pMask == 1].mean()
                West = West - offset
                West[self.pMask == 0] = 0

                # Set dWestimate/dn = 0 around boundary
                WestdWdn0 = West.copy()

                # Do a (2*boundaryT+1) x (2*boundaryT+1) average around each
                # border pixel, including only those pixels inside the
                # aperture (i.e. on the inner ring ApringIn)
                for ii in range(len(borderx)):
                    reg = West[borderx[ii] - boundaryT:
                               borderx[ii] + boundaryT + 1,
                               bordery[ii] - boundaryT:
                               bordery[ii] + boundaryT + 1]

                    intersectIdx = ApringIn[borderx[ii] - boundaryT:
                                            borderx[ii] + boundaryT + 1,
                                            bordery[ii] - boundaryT:
                                            bordery[ii] + boundaryT + 1]

                    WestdWdn0[borderx[ii], bordery[ii]] = \
                        reg[np.nonzero(intersectIdx)].mean()

                # Take Laplacian to find sensor signal estimate (Delta W = S)
                del2W = laplace(WestdWdn0)/dOmega

                # Extend the dimension of signal to the order of 2 for "fft" to
                # use
                Sest = padArray(del2W, padDim)

                # Put signal back inside boundary, leaving the rest of
                # Sestimate
                Sest[self.pMaskPad == 1] = Sini[self.pMaskPad == 1]

                # Need to recheck this condition
                S = Sest

            # Define the estimated wavefront
            # self.West = West.copy()

            # Calculate the coefficient of normal/ annular Zernike polynomials
            # Only fitted when the compensator mode is "zer"; otherwise the
            # coefficients are reported as zeros.
            if (self.getCompensatorMode() == "zer"):
                xSensor, ySensor = self._inst.getSensorCoor()
                zc = ZernikeMaskedFit(West, xSensor, ySensor, numTerms,
                                      self.pMask, zobsR)
            else:
                zc = np.zeros(numTerms)

        elif (PoissonSolver == "exp"):

            # Use the integration method by serial expansion to solve the
            # Poisson's equation

            # Calculate I0 and dI
            I0, dI = self._getdIandI(I1, I2)

            # Get the x, y coordinate in mask. The element outside mask is 0.
            xSensor, ySensor = self._inst.getSensorCoor()
            xSensor = xSensor * self.cMask
            ySensor = ySensor * self.cMask

            # Create the F matrix and Zernike-related matrixes
            F = np.zeros(numTerms)
            dZidx = np.zeros((numTerms, dimOfDonut, dimOfDonut))
            dZidy = dZidx.copy()

            # Unit vector of coefficients reused for every term
            zcCol = np.zeros(numTerms)
            for ii in range(int(numTerms)):

                # Calculate the matrix for each Zk related component
                # Set the specific Zk coefficient to be 1 for the calculation
                zcCol[ii] = 1

                F[ii] = np.sum(dI*ZernikeAnnularEval(zcCol, xSensor, ySensor, zobsR))*dOmega
                dZidx[ii, :, :] = ZernikeAnnularGrad(zcCol, xSensor, ySensor, zobsR, "dx")
                dZidy[ii, :, :] = ZernikeAnnularGrad(zcCol, xSensor, ySensor, zobsR, "dy")

                # Set the specific Zk coefficient back to 0 to avoid
                # interfering other Zk's calculation
                zcCol[ii] = 0

            # Calculate Mij matrix, need to check the stability of integration
            # and symmetry later
            Mij = np.zeros([numTerms, numTerms])
            for ii in range(numTerms):
                for jj in range(numTerms):
                    Mij[ii, jj] = np.sum(I0*(dZidx[ii, :, :].squeeze()*dZidx[jj, :, :].squeeze() +
                                             dZidy[ii, :, :].squeeze()*dZidy[jj, :, :].squeeze()))
            Mij = dOmega/(apertureDiameter/2.)**2 * Mij

            # Calculate dz
            focalLength = self._inst.getFocalLength()
            offset = self._inst.getDefocalDisOffset()
            dz = 2*focalLength*(focalLength-offset)/offset

            # Define zc
            zc = np.zeros(numTerms)

            # Consider specific Zk terms only
            # (getZernikeTerms() is 1-based, hence the "- 1")
            idx = (self.getZernikeTerms() - 1).tolist()

            # Solve the equation: M*W = F => W = M^(-1)*F
            # Only the rows/columns of the selected terms enter the solve.
            zc_tmp = np.linalg.lstsq(Mij[:, idx][idx], F[idx], rcond=None)[0]/dz
            zc[idx] = zc_tmp

            # Estimate the wavefront surface with the first three
            # coefficients forced to 0
            West = ZernikeAnnularEval(np.concatenate(([0, 0, 0], zc[3:])),
                                      xSensor, ySensor, zobsR)

        return zc, West

    def _createSignal(self, I1, I2, cliplevel):
        """Build the padded wavefront signal consumed by the "fft" Poisson
        solver.

        The signal is the normalized intensity difference
        (I1 - I2)/(I1 + I2)/deltaZ, where pixels whose summed intensity is
        below a clip threshold get a fixed replacement denominator.

        Need to discuss the method to define threshold and discuss to use
        np.median() instead.
        Need to discuss why the calculation of I0 is different from "exp".

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        cliplevel : float
            Parameter to determine the threshold of calculating I0.

        Returns
        -------
        numpy.ndarray
            Approximated wavefront signal, padded to the fft dimension and
            multiplied by the padded non-padded-aperture mask (cMaskPad).
        """

        # Validate the pair of images; the second one comes back rotated by
        # 180 degree
        imgFirst, imgSecond = self._checkImageDim(I1, I2)

        # S = -(1/I0)*(dI/dz) is approximated by
        # -(1/delta z)*(I1-I2)/(I1+I2)
        numerator = imgFirst - imgSecond
        denominator = imgFirst + imgSecond

        # Threshold: mid-point between the smallest and the largest non-zero
        # summed intensity inside the aperture mask
        inMask = denominator * self.cMask
        nonZeroPixels = inMask[inMask != 0]
        lowest = nonZeroPixels.min()
        highest = nonZeroPixels.max()
        medianThreshold = (highest-lowest)/2. + lowest

        # Replace too-faint denominators by a fixed effective value
        tooFaint = denominator < medianThreshold*cliplevel
        denominator[tooFaint] = 1.5*medianThreshold

        # delta z = f(f-l)/l, f: focal length, l: defocus distance of the
        # image planes
        focalLength = self._inst.getFocalLength()
        offset = self._inst.getDefocalDisOffset()
        deltaZ = focalLength*(focalLength-offset)/offset

        # A zero denominator is turned into inf so that the signal there
        # becomes 0
        denominator[denominator == 0] = np.inf

        signal = numerator/denominator/deltaZ

        # Extend the dimension of signal to the order of 2 for "fft" to use
        return padArray(signal, self.getFftDimension())*self.cMaskPad

    def _getdIandI(self, I1, I2):
        """Compute the mean image and the difference image used by the
        serial expansion method.

        The images are assumed to be co-centered already, and the intra-/
        extra-focal image is assumed to overlap with the other one after a
        rotation of 180 degree (the rotation is applied to I2 by
        _checkImageDim()).

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.

        Returns
        -------
        numpy.ndarray
            Image data of I0, the average of I1 and the rotated I2.
        numpy.ndarray
            Differential image (dI), the rotated I2 minus I1.
        """

        # Validate the images and get the rotated second image
        imgFirst, imgSecondRot = self._checkImageDim(I1, I2)

        meanImage = (imgFirst+imgSecondRot)/2
        diffImage = imgSecondRot-imgFirst

        return meanImage, diffImage

    def _checkImageDim(self, I1, I2):
        """Validate the shapes of the two defocal images and return their
        data.

        Both images must be square and of equal size. The data of I2 is
        returned rotated by 180 degree.

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.

        Returns
        -------
        numpy.ndarray
            I1 defocal image.
        numpy.ndarray
            I2 defocal image, rotated by 180 degree.

        Raises
        ------
        Exception
            At least one image is not n by n.
        Exception
            The two defocal images do not have the same size.
        """

        rows1, cols1 = I1.getImg().shape
        rows2, cols2 = I2.getImg().shape

        bothSquare = (rows1 == cols1) and (rows2 == cols2)
        if not bothSquare:
            raise Exception("Image is not square.")

        sameSize = (rows1 == rows2) and (cols1 == cols2)
        if not sameSize:
            raise Exception("Images do not have the same size.")

        # The 180 degree rotation of I2 is done as two rotations of 90 degree
        return I1.getImg(), np.rot90(I2.getImg(), k=2)

    def _makeMasterMask(self, I1, I2, poissonSolver=None):
        """Build the masks shared by both defocal images.

        The padded and non-padded masks of I1 and I2 are multiplied
        element-wise and stored in self.pMask and self.cMask.

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        poissonSolver : str, optional
            Algorithm to solve the Poisson's equation. If the "fft" is used,
            padded copies of both masks (self.pMaskPad, self.cMaskPad) are
            also created with the fft dimension. (the default is None.)
        """

        # Keep only the region common to both images. This is to avoid the
        # anomalous signal due to difference in vignetting.
        paddedCommon = I1.getPaddedMask() * I2.getPaddedMask()
        nonPaddedCommon = I1.getNonPaddedMask() * I2.getNonPaddedMask()
        self.pMask = paddedCommon
        self.cMask = nonPaddedCommon

        # Only the "fft" solver needs the masks in the fft dimension
        if poissonSolver != "fft":
            return

        fftDim = self.getFftDimension()
        self.pMaskPad = padArray(self.pMask, fftDim)
        self.cMaskPad = padArray(self.cMask, fftDim)

    def _applyI1I2pMask(self, I1, I2):
        """Mask and normalize the defocal images when I1 and I2 have
        different field positions.

        If fieldX or fieldY differs between the images, each image is
        multiplied by the common padded mask (rotated by 180 degree for I2)
        and then divided by its own sum. Otherwise the images are returned
        untouched.

        (There is a problem for this actually. If I1 and I2 come from different
        sources, what should the correction of TIE be? At this moment, the
        fieldX and fieldY of I1 and I2 should be different. And the sources are
        different also.)

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.

        Returns
        -------
        Image
            I1, updated in place when a correction was applied.
        Image
            I2, updated in place when a correction was applied.
        """

        differentField = (I1.fieldX != I2.fieldX) or (I1.fieldY != I2.fieldY)
        if not differentField:
            # Nothing to correct. It is noted that there is no need of
            # vignetting correction: the masking is done already in
            # _singleItr() or itr0().
            return I1, I2

        # Keep the overlap region only
        maskedFirst = I1.getImg()*self.pMask
        I1.updateImage(maskedFirst)

        # The mask is rotated by 180 degree (two rotations of 90 degree)
        # before being applied to I2
        maskedSecond = I2.getImg()*np.rot90(self.pMask, 2)
        I2.updateImage(maskedSecond)

        # Normalize each image by its total intensity
        I1.updateImage(I1.getImg()/np.sum(I1.getImg()))
        I2.updateImage(I2.getImg()/np.sum(I2.getImg()))

        return I1, I2

    def _reset(self, I1, I2):
        """Restart the outer-loop counter and restore the defocal images to
        their initial data.

        Parameters
        ----------
        I1 : Image
            Intra- or extra-focal image.
        I2 : Image
            Intra- or extra-focal image.
        """

        verbose = (self.debugLevel >= 3)

        # The outer loop starts over
        self.currentItr = 0

        if verbose:
            print("Resetting images: I1 and I2")

        # Restoring relies on the initial image being available. When it is
        # not (AttributeError), the images are left as they are: this is the
        # first time to run the code.
        try:
            for img in (I1, I2):
                img.updateImage(img.getImgInit().copy())

            if verbose:
                print("Resetting images in inside.")

        except AttributeError:
            if verbose:
                print("Image0 = None. This is the first time to run the code.")

    def outZer4Up(self, unit="nm", filename=None, showPlot=False):
        """Put the coefficients of normal/ annular Zernike polynomials on
        terminal or file and show the image if it is needed.

        One line "Z<index> (Z<n>,<m>)<TAB> <value>" is written per
        coefficient, starting at Z4. All lines are formatted before the
        output file is opened, so a formatting failure never leaves a
        truncated or half-written file behind.

        Parameters
        ----------
        unit : str, optional
            Unit of the coefficients of normal/ annular Zernike polynomials. It
            can be m, nm, or um. An unknown unit only prints a message and
            returns. (the default is "nm".)
        filename : str, optional
            Name of output file. If None, the lines go to sys.stdout. (the
            default is None.)
        showPlot : bool, optional
            Decide to show the plot or not. (the default is False.)

        Raises
        ------
        IndexError
            There are more coefficients than entries in the internal Zn,m
            name table (Z4 - Z22).
        """

        # List of Zn,m (entry k names Z(k+1))
        Znm = ["Z0,0", "Z1,1", "Z1,-1", "Z2,0", "Z2,-2", "Z2,2", "Z3,-1",
               "Z3,1", "Z3,-3", "Z3,3", "Z4,0", "Z4,2", "Z4,-2", "Z4,4",
               "Z4,-4", "Z5,1", "Z5,-1", "Z5,3", "Z5,-3", "Z5,5", "Z5,-5",
               "Z6,0"]

        # Decide the format of z based on the input unit (m, nm, or um)
        if (unit == "m"):
            z = self.zer4UpNm*1e-9
        elif (unit == "nm"):
            z = self.zer4UpNm
        elif (unit == "um"):
            z = self.zer4UpNm*1e-3
        else:
            print("Unknown unit: %s" % unit)
            print("Unit options are: m, nm, um")
            return

        # Format everything first: an error here must not clobber an
        # existing output file.
        lines = ["Z%d (%s)\t %8.3f\n" % (ii, Znm[ii-1], coef)
                 for ii, coef in enumerate(z, start=4)]

        # Write the coefficients into a file if needed. The context manager
        # guarantees the file is closed even if the write fails.
        if (filename is not None):
            with open(filename, "w") as f:
                f.writelines(lines)
        else:
            sys.stdout.writelines(lines)

        # Show the plot
        if (showPlot):
            plt.figure()

            x = range(4, len(z) + 4)
            plt.plot(x, z, marker="o", color="r", markersize=10)
            plt.xlabel("Zernike Index")
            plt.ylabel("Zernike coefficient (%s)" % unit)
            plt.grid()
            plt.show()
Пример #23
0
    def _runWEP(
        self,
        instDir,
        algoFolderPath,
        useAlgorithm,
        imageFolderPath,
        intra_image_name,
        extra_image_name,
        fieldXY,
        opticalModel,
        showFig=False,
    ):
        # Run the wavefront estimation on one intra-/extra-focal image pair
        # read from imageFolderPath and return the value of
        # Algorithm.getZer4UpInNm().

        # I1 is the intra-focal image and I2 is the extra-focal image
        I1 = CompensableImage()
        I2 = CompensableImage()

        I1.setImg(
            fieldXY,
            DefocalType.Intra,
            imageFile=os.path.join(imageFolderPath, intra_image_name),
        )
        I2.setImg(
            fieldXY,
            DefocalType.Extra,
            imageFile=os.path.join(imageFolderPath, extra_image_name),
        )

        # The instrument is configured with the size of the intra-focal
        # image
        inst = Instrument(instDir)
        inst.config(
            CamType.LsstCam, I1.getImgSizeInPix(), announcedDefocalDisInMm=1.0
        )

        # The algorithm to be used
        algo = Algorithm(algoFolderPath)
        algo.config(useAlgorithm, inst, debugLevel=0)

        # Original wavefront images
        if showFig:
            for img, title in ((I1, "intra image"), (I2, "extra image")):
                plotImage(img.image, title=title)

        algo.runIt(I1, I2, opticalModel, tol=1e-3)

        # Print the Zernikes Zn (n>=4)
        algo.outZer4Up(showPlot=False)

        # Final compensated images and wavefront
        if showFig:
            plotImage(I1.image, title="Compensated intra image")
            plotImage(I2.image, title="Compensated extra image")

            finalWavefront = algo.wcomp
            plotImage(finalWavefront, title="Final wavefront")
            plotImage(
                algo.wcomp,
                title="Final wavefront with pupil mask applied",
                mask=algo.pMask,
            )

        # The Zernikes Zn (n>=4)
        return algo.getZer4UpInNm()
Пример #24
0
class WfEstimator(object):
    def __init__(self, instDir, algoDir):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instDir : str
            Path to instrument directory.
        algoDir : str
            Path to algorithm directory.
        """

        self.inst = Instrument(instDir)
        self.algo = Algorithm(algoDir)

        self.imgIntra = CompensableImage()
        self.imgExtra = CompensableImage()

        # Both are overwritten by config()
        self.opticalModel = ""
        self.sizeInPix = 0

    def getAlgo(self):
        """Get the algorithm object.

        Returns
        -------
        Algorithm
            Algorithm object.
        """

        return self.algo

    def getInst(self):
        """Get the instrument object.

        Returns
        -------
        Instrument
            Instrument object.
        """

        return self.inst

    def getIntraImg(self):
        """Get the intra-focal donut image.

        Returns
        -------
        CompensableImage
            Intra-focal donut image.
        """

        return self.imgIntra

    def getExtraImg(self):
        """Get the extra-focal donut image.

        Returns
        -------
        CompensableImage
            Extra-focal donut image.
        """

        return self.imgExtra

    def getOptModel(self):
        """Get the optical model.

        Returns
        -------
        str
            Optical model.
        """

        return self.opticalModel

    def getSizeInPix(self):
        """Get the donut image size in pixel defined by the config() function.

        Returns
        -------
        int
            Donut image size in pixel
        """

        return self.sizeInPix

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings.

        This only delegates to Algorithm.reset().
        """

        self.algo.reset()

    def config(
        self,
        solver="exp",
        camType=CamType.LsstCam,
        opticalModel="offAxis",
        defocalDisInMm=1.5,
        sizeInPix=120,
        centroidFindType=CentroidFindType.RandomWalk,
        debugLevel=0,
    ):
        """Configure the TIE solver.

        Parameters
        ----------
        solver : str, optional
            Algorithm to solve the Poisson's equation in the transport of
            intensity equation (TIE). It can be "fft" or "exp" here. (the
            default is "exp".)
        camType : enum 'CamType', optional
            Camera type. (the default is CamType.LsstCam.)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis". (the
            default is "offAxis".)
        defocalDisInMm : float, optional
            Defocal distance in mm. (the default is 1.5.)
        sizeInPix : int, optional
            Wavefront image pixel size. (the default is 120.)
        centroidFindType : enum 'CentroidFindType', optional
            Algorithm to find the centroid of donut. If it is not
            CentroidFindType.RandomWalk, both image objects are replaced by
            new ones using this type. (the default is
            CentroidFindType.RandomWalk.)
        debugLevel : int, optional
            Show the information under the running. If the value is higher,
            the information shows more. It can be 0, 1, 2, or 3. (the default
            is 0.)

        Raises
        ------
        ValueError
            Wrong Poisson solver name.
        ValueError
            Wrong optical model.
        """

        # Validate the inputs before changing any state
        if solver not in ("exp", "fft"):
            raise ValueError("Poisson solver can not be '%s'." % solver)

        if opticalModel not in ("paraxial", "onAxis", "offAxis"):
            raise ValueError("Optical model can not be '%s'." % opticalModel)

        self.opticalModel = opticalModel

        # Update the instrument with the donut size and defocal distance
        self.sizeInPix = int(sizeInPix)
        self.inst.config(camType,
                         self.sizeInPix,
                         announcedDefocalDisInMm=defocalDisInMm)

        self.algo.config(solver, self.inst, debugLevel=debugLevel)

        # Reset the centroid find algorithm if not the default one.
        # This replaces the image objects, so images set before config() are
        # dropped in that case.
        if centroidFindType != CentroidFindType.RandomWalk:
            self.imgIntra = CompensableImage(centroidFindType=centroidFindType)
            self.imgExtra = CompensableImage(centroidFindType=centroidFindType)

    def setImg(self, fieldXY, defocalType, image=None, imageFile=None):
        """Set the wavefront image.

        Parameters
        ----------
        fieldXY : tuple or list
            Position of donut on the focal plane in degree for intra- and
            extra-focal images.
        defocalType : enum 'DefocalType'
            Defocal type of image. It selects which of the intra-/
            extra-focal image objects receives the data.
        image : numpy.ndarray, optional
            Array of image. (the default is None.)
        imageFile : str, optional
            Path of image file. (the default is None.)

        Raises
        ------
        ValueError
            Defocal type is neither DefocalType.Intra nor DefocalType.Extra.
        """

        if defocalType == DefocalType.Intra:
            img = self.imgIntra
        elif defocalType == DefocalType.Extra:
            img = self.imgExtra
        else:
            # Without this branch "img" would be unbound below and the
            # caller would get an obscure UnboundLocalError.
            raise ValueError("Defocal type can not be '%s'." % defocalType)

        img.setImg(fieldXY, defocalType, image=image, imageFile=imageFile)

    def calWfsErr(self, tol=1e-3, showZer=False, showPlot=False):
        """Calculate the wavefront error.

        Parameters
        ----------
        tol : float, optional
            Tolerance forwarded to Algorithm.runIt() as its "tol" argument.
            Presumably the convergence threshold of the outer-loop
            iteration - confirm against Algorithm. (the default is 1e-3.)
        showZer : bool, optional
            Decide to show the annular Zernike polynomials or not. (the
            default is False.)
        showPlot : bool, optional
            Decide to show the plot or not. Only used when showZer is True.
            (the default is False.)

        Returns
        -------
        numpy.ndarray
            Coefficients of Zernike polynomials (z4 - z22), as returned by
            Algorithm.getZer4UpInNm().

        Raises
        ------
        RuntimeError
            Input image shape is wrong.
        """

        # Both images must match the donut size given to config()
        for img in (self.imgIntra, self.imgExtra):
            d1, d2 = img.getImg().shape
            if (d1 != self.sizeInPix) or (d2 != self.sizeInPix):
                raise RuntimeError(
                    "Input image shape is (%d, %d), not required (%d, %d)" %
                    (d1, d2, self.sizeInPix, self.sizeInPix))

        # Calculate the wavefront error.
        # Run cwfs
        self.algo.runIt(self.imgIntra,
                        self.imgExtra,
                        self.opticalModel,
                        tol=tol)

        # Show the Zernikes Zn (n>=4)
        if showZer:
            self.algo.outZer4Up(showPlot=showPlot)

        return self.algo.getZer4UpInNm()
Пример #25
0
class TestInstrument(unittest.TestCase):
    """Test the Instrument class."""

    def setUp(self):
        # Every test works on an LSST camera instrument configured for
        # 120-pixel donuts and a defocal distance of 1.5 mm.
        self.dimOfDonutOnSensor = 120
        self.instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        self.inst = Instrument(self.instDir)
        self.inst.config(CamType.LsstCam, self.dimOfDonutOnSensor,
                         announcedDefocalDisInMm=1.5)

    def testConfigWithUnsupportedCamType(self):

        with self.assertRaises(ValueError):
            self.inst.config(CamType.LsstFamCam, 120)

    def testGetInstFileDir(self):

        expectedDir = os.path.join(self.instDir, "lsst")
        self.assertEqual(self.inst.getInstFileDir(), expectedDir)

    def testGetAnnDefocalDisInMm(self):

        self.assertEqual(self.inst.getAnnDefocalDisInMm(), 1.5)

    def testSetAnnDefocalDisInMm(self):

        newDefocalDisInMm = 2.0
        self.inst.setAnnDefocalDisInMm(newDefocalDisInMm)

        self.assertEqual(self.inst.getAnnDefocalDisInMm(), newDefocalDisInMm)

    def testGetInstFilePath(self):

        paramFilePath = self.inst.getInstFilePath()
        self.assertTrue(os.path.exists(paramFilePath))
        self.assertEqual(os.path.basename(paramFilePath), "instParam.yaml")

    def testGetMaskOffAxisCorr(self):

        corr = self.inst.getMaskOffAxisCorr()
        self.assertEqual(corr.shape, (9, 5))
        self.assertEqual(corr[0, 0], 1.07)
        self.assertEqual(corr[2, 3], -0.090100858)

    def testGetDimOfDonutOnSensor(self):

        self.assertEqual(self.inst.getDimOfDonutOnSensor(),
                         self.dimOfDonutOnSensor)

    def testGetObscuration(self):

        self.assertEqual(self.inst.getObscuration(), 0.61)

    def testGetFocalLength(self):

        self.assertEqual(self.inst.getFocalLength(), 10.312)

    def testGetApertureDiameter(self):

        self.assertEqual(self.inst.getApertureDiameter(), 8.36)

    def testGetDefocalDisOffset(self):

        offsetInM = self.inst.getDefocalDisOffset()

        # The answer is 1.5 mm
        self.assertEqual(offsetInM * 1e3, 1.5)

    def testGetCamPixelSize(self):

        pixelSizeInM = self.inst.getCamPixelSize()

        # The answer is 10 um
        self.assertEqual(pixelSizeInM * 1e6, 10)

    def testGetMarginalFocalLength(self):

        self.assertAlmostEqual(self.inst.getMarginalFocalLength(), 9.4268,
                               places=4)

    def testGetSensorFactor(self):

        self.assertAlmostEqual(self.inst.getSensorFactor(), 0.98679,
                               places=5)

    def testGetSensorCoor(self):

        expectedShape = (self.dimOfDonutOnSensor, self.dimOfDonutOnSensor)
        xGrid, yGrid = self.inst.getSensorCoor()

        # x changes along the second axis
        self.assertEqual(xGrid.shape, expectedShape)
        self.assertAlmostEqual(xGrid[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(xGrid[0, 1], -0.96212, places=5)

        # y changes along the first axis
        self.assertEqual(yGrid.shape, expectedShape)
        self.assertAlmostEqual(yGrid[0, 0], -0.97857, places=5)
        self.assertAlmostEqual(yGrid[1, 0], -0.96212, places=5)

    def testGetSensorCoorAnnular(self):

        expectedShape = (self.dimOfDonutOnSensor, self.dimOfDonutOnSensor)
        xoGrid, yoGrid = self.inst.getSensorCoorAnnular()

        # The corner and the center are both expected to be NaN
        self.assertEqual(xoGrid.shape, expectedShape)
        self.assertTrue(np.isnan(xoGrid[0, 0]))
        self.assertTrue(np.isnan(xoGrid[60, 60]))

        self.assertEqual(yoGrid.shape, expectedShape)
        self.assertTrue(np.isnan(yoGrid[0, 0]))
        self.assertTrue(np.isnan(yoGrid[60, 60]))
Пример #26
0
class Algorithm(object):
    def __init__(self, algoDir):
        """Create an Algorithm object with empty, not-yet-configured state.

        The algorithm solves the transport of intensity equation (TIE) to get
        the normal/ annular Zernike polynomials.

        Parameters
        ----------
        algoDir : str
            Algorithm configuration directory.
        """

        # Configuration directory and the reader of the parameter file
        self.algoDir = algoDir
        self.algoParamFile = ParamReader()

        # Instrument built from an empty directory string
        self._inst = Instrument("")

        # Verbosity of the calculation messages; 0 shows no message
        self.debugLevel = 0

        # Whether the image has a problem from the over-compensation
        self.caustic = False

        # Zk coefficients of each outer-loop iteration. The actual total
        # number of outer-loop iterations is Num_of_outer_itr + 1.
        self.converge = np.array([])

        # Index of the current outer-loop iteration
        self.currentItr = 0

        # Coefficients of normal/ annular Zernike polynomials after z4 in
        # unit of nm
        self.zer4UpNm = np.array([])

        # Converged wavefront, and the wavefront calculated in the previous
        # outer-loop iteration
        self.wcomp = np.array([])
        self.West = np.array([])

        # Converged Zk coefficients, and the Zk coefficients calculated in
        # the previous outer-loop iteration
        self.zcomp = np.array([])
        self.zc = np.array([])

        # Padded mask for use at the offset planes, and non-padded mask
        # corresponding to the aperture
        self.pMask, self.cMask = None, None

        # Masks with the dimension changed for the fft to use
        self.pMaskPad, self.cMaskPad = None, None

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings."""

        self.caustic = False
        self.converge = np.zeros(self.converge.shape)
        self.currentItr = 0
        self.zer4UpNm = np.zeros(self.zer4UpNm.shape)

        self.wcomp = np.zeros(self.wcomp.shape)
        self.West = np.zeros(self.West.shape)

        self.zcomp = np.zeros(self.zcomp.shape)
        self.zc = np.zeros(self.zc.shape)

        self.pMask = None
        self.cMask = None

        self.pMaskPad = None
        self.cMaskPad = None

    def config(self, algoName, inst, debugLevel=0):
        """Configure the algorithm to solve TIE.

        Parameters
        ----------
        algoName : str
            Algorithm configuration file to solve the Poisson's equation in the
            transport of intensity equation (TIE). It can be "fft" or "exp"
            here.
        inst : Instrument
            Instrument to use.
        debugLevel : int, optional
            Show the information under the running. If the value is higher, the
            information shows more. It can be 0, 1, 2, or 3. (the default is
            0.)
        """

        algoParamFilePath = os.path.join(self.algoDir, "%s.yaml" % algoName)
        self.algoParamFile.setFilePath(algoParamFilePath)

        self._inst = inst
        self.debugLevel = debugLevel

        self.caustic = False

        numTerms = self.getNumOfZernikes()
        outerItr = self.getNumOfOuterItr()
        self.converge = np.zeros((numTerms, outerItr + 1))

        self.currentItr = 0

        self.zer4UpNm = np.zeros(numTerms - 3)

        # Wavefront related parameters
        dimOfDonut = self._inst.getDimOfDonutOnSensor()
        self.wcomp = np.zeros((dimOfDonut, dimOfDonut))
        self.West = self.wcomp.copy()

        # Used in model basis ("zer").
        self.zcomp = np.zeros(numTerms)
        self.zc = self.zcomp.copy()

        # Mask related variables
        self.pMask = None
        self.cMask = None
        self.pMaskPad = None
        self.cMaskPad = None

    def setDebugLevel(self, debugLevel):
        """Set the debug level.

        If the value is higher, the information shows more. It can be 0, 1, 2,
        or 3.

        Parameters
        ----------
        debugLevel : int
            Show the information under the running.
        """

        self.debugLevel = int(debugLevel)

    def getDebugLevel(self):
        """Get the debug level.

        If the value is higher, the information shows more. It can be 0, 1, 2,
        or 3.

        Returns
        -------
        int
            Debug level.
        """

        return self.debugLevel

    def getZer4UpInNm(self):
        """Get the coefficients of Zernike polynomials of z4-zn in nm.

        Returns
        -------
        numpy.ndarray
            Zernike polynomials of z4-zn in nm.
        """

        return self.zer4UpNm

    def getPoissonSolverName(self):
        """Get the method name to solve the Poisson equation.

        Returns
        -------
        str
            Method name to solve the Poisson equation.
        """

        return self.algoParamFile.getSetting("poissonSolver")

    def getNumOfZernikes(self):
        """Get the maximum number of Zernike polynomials supported.

        Returns
        -------
        int
            Maximum number of Zernike polynomials supported.
        """

        return int(self.algoParamFile.getSetting("numOfZernikes"))

    def getZernikeTerms(self):
        """Get the Zernike terms in using.

        Returns
        -------
        list[int]
            Zernike terms in using.
        """

        numTerms = self.getNumOfZernikes()

        return list(range(numTerms))

    def getObsOfZernikes(self):
        """Get the obscuration of annular Zernike polynomials.

        Returns
        -------
        float
            Obscuration of annular Zernike polynomials
        """

        zobsR = self.algoParamFile.getSetting("obsOfZernikes")
        if zobsR == 1:
            zobsR = self._inst.getObscuration()

        return float(zobsR)

    def getNumOfOuterItr(self):
        """Get the number of outer loop iteration.

        Returns
        -------
        int
            Number of outer loop iteration.
        """

        return int(self.algoParamFile.getSetting("numOfOuterItr"))

    def getNumOfInnerItr(self):
        """Get the number of inner loop iteration.

        This is for the fast Fourier transform (FFT) solver only.

        Returns
        -------
        int
            Number of inner loop iteration.
        """

        return int(self.algoParamFile.getSetting("numOfInnerItr"))

    def getFeedbackGain(self):
        """Get the gain value used in the outer loop iteration.

        Returns
        -------
        float
            Gain value used in the outer loop iteration.
        """

        return self.algoParamFile.getSetting("feedbackGain")

    def getOffAxisPolyOrder(self):
        """Get the number of polynomial order supported in off-axis correction.

        Returns
        -------
        int
            Number of polynomial order supported in off-axis correction.
        """

        return int(self.algoParamFile.getSetting("offAxisPolyOrder"))

    def getCompensatorMode(self):
        """Get the method name to compensate the wavefront by wavefront error.

        Returns
        -------
        str
            Method name to compensate the wavefront by wavefront error.
        """

        return self.algoParamFile.getSetting("compensatorMode")

    def getCompSequence(self):
        """Get the compensated sequence of Zernike order for each iteration.

        Returns
        -------
        numpy.ndarray[int]
            Compensated sequence of Zernike order for each iteration.
        """

        compSequenceFromFile = self.algoParamFile.getSetting("compSequence")
        compSequence = np.array(compSequenceFromFile, dtype=int)

        # If outerItr is large, and compSequence is too small,
        # the rest in compSequence will be filled.
        # This is used in the "zer" method.
        outerItr = self.getNumOfOuterItr()
        compSequence = self._extend1dArray(compSequence, outerItr)
        compSequence = compSequence.astype(int)

        return compSequence

    def _extend1dArray(self, origArray, targetLength):
        """Extend the 1D original array to the taget length.

        The extended value will be the final element of original array. Nothing
        will be done if the input array is not 1D or its length is less than
        the target.

        Parameters
        ----------
        origArray : numpy.ndarray
            Original array with 1 dimension.
        targetLength : int
            Target length of new extended array.

        Returns
        -------
        numpy.ndarray
            Extended 1D array.
        """

        if (len(origArray) < targetLength) and (origArray.ndim == 1):
            leftOver = np.ones(targetLength - len(origArray))
            extendArray = np.append(origArray, origArray[-1] * leftOver)
        else:
            extendArray = origArray

        return extendArray

    def getBoundaryThickness(self):
        """Get the boundary thickness that the computation mask extends beyond
        the pupil mask.

        It is noted that in Fast Fourier transform (FFT) algorithm, it is also
        the width of Neuman boundary where the derivative of the wavefront is
        set to zero

        Returns
        -------
        int
            Boundary thickness.
        """

        return int(self.algoParamFile.getSetting("boundaryThickness"))

    def getFftDimension(self):
        """Get the FFT pad dimension in pixel.

        This is for the fast Fourier transform (FFT) solver only.

        Returns
        -------
        int
            FFT pad dimention.
        """

        fftDim = int(self.algoParamFile.getSetting("fftDimension"))

        # Make sure the dimension is the order of multiple of 2
        if fftDim == 999:
            dimToFit = self._inst.getDimOfDonutOnSensor()
        else:
            dimToFit = fftDim

        padDim = int(2 ** np.ceil(np.log2(dimToFit)))

        return padDim

    def getSignalClipSequence(self):
        """Get the signal clip sequence.

        The number of values should be the number of compensation plus 1.
        This is for the fast Fourier transform (FFT) solver only.

        Returns
        -------
        numpy.ndarray
            Signal clip sequence.
        """

        sumclipSequenceFromFile = self.algoParamFile.getSetting("signalClipSequence")
        sumclipSequence = np.array(sumclipSequenceFromFile)

        # If outerItr is large, and sumclipSequence is too small, the rest in
        # sumclipSequence will be filled.
        # This is used in the "zer" method.
        targetLength = self.getNumOfOuterItr() + 1
        sumclipSequence = self._extend1dArray(sumclipSequence, targetLength)

        return sumclipSequence

    def getMaskScalingFactor(self):
        """Get the mask scaling factor for fast beam.

        Returns
        -------
        float
            Mask scaling factor for fast beam.
        """

        # m = R'*f/(l*R), R': radius of the no-aberration image
        focalLength = self._inst.getFocalLength()
        marginalFL = self._inst.getMarginalFocalLength()
        maskScalingFactor = focalLength / marginalFL

        return maskScalingFactor

    def getWavefrontMapEsti(self):
        """Get the estimated wavefront map.

        Returns
        -------
        numpy.ndarray
            Estimated wavefront map.
        """

        return self._getWavefrontMapWithMaskApplied(self.wcomp)

    def getWavefrontMapResidual(self):
        """Get the residual wavefront map.

        Returns
        -------
        numpy.ndarray
            Residual wavefront map.
        """

        return self._getWavefrontMapWithMaskApplied(self.West)

    def _getWavefrontMapWithMaskApplied(self, wfMap):
        """Get the wavefront map with mask applied.

        Parameters
        ----------
        wfMap : numpy.ndarray
            Wavefront map.

        Returns
        -------
        numpy.ndarray
            Wavefront map with mask applied.
        """

        self._checkNotItr0()

        wfMapWithMask = wfMap.copy()
        wfMapWithMask[self.pMask == 0] = np.nan

        return wfMapWithMask

    def _checkNotItr0(self):
        """Check not in the iteration 0.

        TIE: Transport of intensity equation.

        Raises
        ------
        RuntimeError
            Need to solve the TIE first.
        """

        if self.currentItr == 0:
            raise RuntimeError("Need to solve the TIE first.")

    def itr0(self, I1, I2, model):
        """Calculate the wavefront and coefficients of normal/ annular Zernike
        polynomials in the first iteration time.

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        """

        # Reset the iteration time of outer loop and decide to reset the
        # defocal images or not
        self._reset(I1, I2)

        # Solve the transport of intensity equation (TIE)
        self._singleItr(I1, I2, model)

    def runIt(self, I1, I2, model, tol=1e-3):
        """Calculate the wavefront error by solving the transport of intensity
        equation (TIE).

        The inner (for fft algorithm) and outer loops are used. The inner loop
        is to solve the Poisson's equation. The outer loop is to compensate the
        intra- and extra-focal images to mitigate the calculation of wavefront
        (e.g. S = -1/(delta Z) * (I1 - I2)/ (I1 + I2)).

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        tol : float, optional
            Tolerance of difference of coefficients of Zk polynomials compared
            with the previours iteration. (the default is 1e-3.)
        """

        # To have the iteration time initiated from global variable is to
        # distinguish the manually and automatically iteration processes.
        itr = self.currentItr
        while itr <= self.getNumOfOuterItr():
            stopItr = self._singleItr(I1, I2, model, tol)

            # Stop the iteration of outer loop if converged
            if stopItr:
                break

            itr += 1

    def nextItr(self, I1, I2, model, nItr=1):
        """Run the outer loop iteration with the specific time defined in nItr.

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        nItr : int, optional
            Outer loop iteration time. (the default is 1.)
        """

        #  Do the iteration
        ii = 0
        while ii < nItr:
            self._singleItr(I1, I2, model)
            ii += 1

    def _singleItr(self, I1, I2, model, tol=1e-3):
        """Run the outer-loop with single iteration to solve the transport of
        intensity equation (TIE).

        This is to compensate the approximation of wavefront:
        S = -1/(delta Z) * (I1 - I2)/ (I1 + I2)).

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        model : str
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
        tol : float, optional
            Tolerance of difference of coefficients of Zk polynomials compared
            with the previours iteration. (the default is 1e-3.)

        Returns
        -------
        bool
            Status of iteration.
        """

        # Use the zonal mode ("zer")
        compMode = self.getCompensatorMode()

        # Define the gain of feedbackGain
        feedbackGain = self.getFeedbackGain()

        # Set the pre-condition
        if self.currentItr == 0:

            # Check this is the first time of running iteration or not
            if I1.getImgInit() is None or I2.getImgInit() is None:

                # Check the image dimension
                if I1.getImg().shape != I2.getImg().shape:
                    print(
                        "Error: The intra and extra image stamps need to be of same size."
                    )
                    sys.exit()

                # Calculate the pupil mask (binary matrix) and related
                # parameters
                boundaryT = self.getBoundaryThickness()
                I1.makeMask(self._inst, model, boundaryT, 1)
                I2.makeMask(self._inst, model, boundaryT, 1)
                self._makeMasterMask(I1, I2, self.getPoissonSolverName())

                # Load the offAxis correction coefficients
                if model == "offAxis":
                    offAxisPolyOrder = self.getOffAxisPolyOrder()
                    I1.setOffAxisCorr(self._inst, offAxisPolyOrder)
                    I2.setOffAxisCorr(self._inst, offAxisPolyOrder)

                # Cocenter the images to the center referenced to fieldX and
                # fieldY. Need to check the availability of this.
                I1.imageCoCenter(self._inst, debugLevel=self.debugLevel)
                I2.imageCoCenter(self._inst, debugLevel=self.debugLevel)

                # Update the self-initial image
                I1.updateImgInit()
                I2.updateImgInit()

            # Initialize the variables used in the iteration.
            self.zcomp = np.zeros(self.getNumOfZernikes())
            self.zc = self.zcomp.copy()

            dimOfDonut = self._inst.getDimOfDonutOnSensor()
            self.wcomp = np.zeros((dimOfDonut, dimOfDonut))
            self.West = self.wcomp.copy()

            self.caustic = False

        # Rename this index (currentItr) for the simplification
        jj = self.currentItr

        # Solve the transport of intensity equation (TIE)
        if not self.caustic:

            # Reset the images before the compensation
            I1.updateImage(I1.getImgInit().copy())
            I2.updateImage(I2.getImgInit().copy())

            if compMode == "zer":

                # Zk coefficient from the previous iteration
                ztmp = self.zc.copy()

                # Do the feedback of Zk from the lower terms first based on the
                # sequence defined in compSequence
                if jj != 0:
                    compSequence = self.getCompSequence()
                    ztmp[int(compSequence[jj - 1]) :] = 0

                # Add partial feedback of residual estimated wavefront in Zk
                self.zcomp = self.zcomp + ztmp * feedbackGain

                # Remove the image distortion by forwarding the image to pupil
                I1.compensate(self._inst, self, self.zcomp, model)
                I2.compensate(self._inst, self, self.zcomp, model)

            # Check the image condition. If there is the problem, done with
            # this _singleItr().
            if (I1.isCaustic() is True) or (I2.isCaustic() is True):
                self.converge[:, jj] = self.converge[:, jj - 1]
                self.caustic = True
                return

            # Correct the defocal images if I1 and I2 are belong to different
            # sources, which is determined by the (fieldX, field Y)
            I1, I2 = self._applyI1I2pMask(I1, I2)

            # Solve the Poisson's equation
            self.zc, self.West = self._solvePoissonEq(I1, I2, jj)

            # Record/ calculate the Zk coefficient and wavefront
            if compMode == "zer":
                self.converge[:, jj] = self.zcomp + self.zc

                xoSensor, yoSensor = self._inst.getSensorCoorAnnular()
                self.wcomp = self.West + ZernikeAnnularEval(
                    np.concatenate(([0, 0, 0], self.zcomp[3:])),
                    xoSensor,
                    yoSensor,
                    self.getObsOfZernikes(),
                )

        else:
            # Once we run into caustic, stop here, results may be close to real
            # aberration.
            # Continuation may lead to disatrous results.
            self.converge[:, jj] = self.converge[:, jj - 1]

        # Record the coefficients of normal/ annular Zernike polynomials after
        # z4 in unit of nm
        self.zer4UpNm = self.converge[3:, jj] * 1e9

        # Status of iteration
        stopItr = False

        # Calculate the difference
        if jj > 0:
            diffZk = (
                np.sum(np.abs(self.converge[:, jj] - self.converge[:, jj - 1])) * 1e9
            )

            # Check the Status of iteration
            if diffZk < tol:
                stopItr = True

        # Update the current iteration time
        self.currentItr += 1

        # Show the Zk coefficients in interger in each iteration
        if self.debugLevel >= 2:
            print("itr = %d, z4-z%d" % (jj, self.getNumOfZernikes()))
            print(np.rint(self.zer4UpNm))

        return stopItr

    def _solvePoissonEq(self, I1, I2, iOutItr=0):
        """Solve the Poisson's equation by Fourier transform (differential) or
        serial expansion (integration).

        There is no convergence for fft actually. Need to add the difference
        comparison and X-alpha method. Need to discuss further for this.

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        iOutItr : int, optional
            ith number of outer loop iteration which is important in "fft"
            algorithm. (the default is 0.)

        Returns
        -------
        numpy.ndarray
            Coefficients of normal/ annular Zernike polynomials.
        numpy.ndarray
            Estimated wavefront.
        """

        # Calculate the aperature pixel size
        apertureDiameter = self._inst.getApertureDiameter()
        sensorFactor = self._inst.getSensorFactor()
        dimOfDonut = self._inst.getDimOfDonutOnSensor()
        aperturePixelSize = apertureDiameter * sensorFactor / dimOfDonut

        # Calculate the differential Omega
        dOmega = aperturePixelSize ** 2

        # Solve the Poisson's equation based on the type of algorithm
        numTerms = self.getNumOfZernikes()
        zobsR = self.getObsOfZernikes()
        PoissonSolver = self.getPoissonSolverName()
        if PoissonSolver == "fft":

            # Use the differential method by fft to solve the Poisson's
            # equation

            # Parameter to determine the threshold of calculating I0.
            sumclipSequence = self.getSignalClipSequence()
            cliplevel = sumclipSequence[iOutItr]

            # Generate the v, u-coordinates on pupil plane
            padDim = self.getFftDimension()
            v, u = np.mgrid[
                -0.5
                / aperturePixelSize : 0.5
                / aperturePixelSize : 1.0
                / padDim
                / aperturePixelSize,
                -0.5
                / aperturePixelSize : 0.5
                / aperturePixelSize : 1.0
                / padDim
                / aperturePixelSize,
            ]

            # Show the threshold and pupil coordinate information
            if self.debugLevel >= 3:
                print("iOuter=%d, cliplevel=%4.2f" % (iOutItr, cliplevel))
                print(v.shape)

            # Calculate the const of fft:
            # FT{Delta W} = -4*pi^2*(u^2+v^2) * FT{W}
            u2v2 = -4 * (np.pi ** 2) * (u * u + v * v)

            # Set origin to Inf to result in 0 at origin after filtering
            ctrIdx = int(np.floor(padDim / 2.0))
            u2v2[ctrIdx, ctrIdx] = np.inf

            # Calculate the wavefront signal
            Sini = self._createSignal(I1, I2, cliplevel)

            # Find the just-outside and just-inside indices of a ring in pixels
            # This is for the use in setting dWdn = 0
            boundaryT = self.getBoundaryThickness()

            struct = generate_binary_structure(2, 1)
            struct = iterate_structure(struct, boundaryT)

            ApringOut = np.logical_xor(
                binary_dilation(self.pMask, structure=struct), self.pMask
            ).astype(int)
            ApringIn = np.logical_xor(
                binary_erosion(self.pMask, structure=struct), self.pMask
            ).astype(int)

            bordery, borderx = np.nonzero(ApringOut)

            # Put the signal in boundary (since there's no existing Sestimate,
            # S just equals self.S as the initial condition of SCF
            S = Sini.copy()
            for jj in range(self.getNumOfInnerItr()):

                # Calculate FT{S}
                SFFT = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(S)))

                # Calculate W by W=IFT{ FT{S}/(-4*pi^2*(u^2+v^2)) }
                W = np.fft.fftshift(
                    np.fft.irfft2(np.fft.fftshift(SFFT / u2v2), s=S.shape)
                )

                # Estimate the wavefront (includes zeroing offset & masking to
                # the aperture size)

                # Take the estimated wavefront
                West = extractArray(W, dimOfDonut)

                # Calculate the offset
                offset = West[self.pMask == 1].mean()
                West = West - offset
                West[self.pMask == 0] = 0

                # Set dWestimate/dn = 0 around boundary
                WestdWdn0 = West.copy()

                # Do a 3x3 average around each border pixel, including only
                # those pixels inside the aperture
                for ii in range(len(borderx)):
                    reg = West[
                        borderx[ii] - boundaryT : borderx[ii] + boundaryT + 1,
                        bordery[ii] - boundaryT : bordery[ii] + boundaryT + 1,
                    ]

                    intersectIdx = ApringIn[
                        borderx[ii] - boundaryT : borderx[ii] + boundaryT + 1,
                        bordery[ii] - boundaryT : bordery[ii] + boundaryT + 1,
                    ]

                    WestdWdn0[borderx[ii], bordery[ii]] = reg[
                        np.nonzero(intersectIdx)
                    ].mean()

                # Take Laplacian to find sensor signal estimate (Delta W = S)
                del2W = laplace(WestdWdn0) / dOmega

                # Extend the dimension of signal to the order of 2 for "fft" to
                # use
                Sest = padArray(del2W, padDim)

                # Put signal back inside boundary, leaving the rest of
                # Sestimate
                Sest[self.pMaskPad == 1] = Sini[self.pMaskPad == 1]

                # Need to recheck this condition
                S = Sest

            # Calculate the coefficient of normal/ annular Zernike polynomials
            if self.getCompensatorMode() == "zer":
                xSensor, ySensor = self._inst.getSensorCoor()
                zc = ZernikeMaskedFit(
                    West, xSensor, ySensor, numTerms, self.pMask, zobsR
                )
            else:
                zc = np.zeros(numTerms)

        elif PoissonSolver == "exp":

            # Use the integration method by serial expansion to solve the
            # Poisson's equation

            # Calculate I0 and dI
            I0, dI = self._getdIandI(I1, I2)

            # Get the x, y coordinate in mask. The element outside mask is 0.
            xSensor, ySensor = self._inst.getSensorCoor()
            xSensor = xSensor * self.cMask
            ySensor = ySensor * self.cMask

            # Create the F matrix and Zernike-related matrixes
            F = np.zeros(numTerms)
            dZidx = np.zeros((numTerms, dimOfDonut, dimOfDonut))
            dZidy = dZidx.copy()

            zcCol = np.zeros(numTerms)
            for ii in range(int(numTerms)):

                # Calculate the matrix for each Zk related component
                # Set the specific Zk cofficient to be 1 for the calculation
                zcCol[ii] = 1

                F[ii] = (
                    np.sum(dI * ZernikeAnnularEval(zcCol, xSensor, ySensor, zobsR))
                    * dOmega
                )
                dZidx[ii, :, :] = ZernikeAnnularGrad(
                    zcCol, xSensor, ySensor, zobsR, "dx"
                )
                dZidy[ii, :, :] = ZernikeAnnularGrad(
                    zcCol, xSensor, ySensor, zobsR, "dy"
                )

                # Set the specific Zk cofficient back to 0 to avoid interfering
                # other Zk's calculation
                zcCol[ii] = 0

            # Calculate Mij matrix, need to check the stability of integration
            # and symmetry later
            Mij = np.zeros([numTerms, numTerms])
            for ii in range(numTerms):
                for jj in range(numTerms):
                    Mij[ii, jj] = np.sum(
                        I0
                        * (
                            dZidx[ii, :, :].squeeze() * dZidx[jj, :, :].squeeze()
                            + dZidy[ii, :, :].squeeze() * dZidy[jj, :, :].squeeze()
                        )
                    )
            Mij = dOmega / (apertureDiameter / 2.0) ** 2 * Mij

            # Calculate dz
            focalLength = self._inst.getFocalLength()
            offset = self._inst.getDefocalDisOffset()
            dz = 2 * focalLength * (focalLength - offset) / offset

            # Define zc
            zc = np.zeros(numTerms)

            # Consider specific Zk terms only
            idx = self.getZernikeTerms()

            # Solve the equation: M*W = F => W = M^(-1)*F
            zc_tmp = np.linalg.lstsq(Mij[:, idx][idx], F[idx], rcond=None)[0] / dz
            zc[idx] = zc_tmp

            # Estimate the wavefront surface based on z4 - z22
            # z0 - z3 are set to be 0 instead
            West = ZernikeAnnularEval(
                np.concatenate(([0, 0, 0], zc[3:])), xSensor, ySensor, zobsR
            )

        return zc, West

    def _createSignal(self, I1, I2, cliplevel):
        """Build the wavefront signal for "fft" to use in solving the
        Poisson's equation.

        Need to discuss the method to define threshold and discuss to use
        np.median() instead.
        Need to discuss why the calculation of I0 is different from "exp".

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        cliplevel : float
            Parameter to determine the threshold of calculating I0.

        Returns
        -------
        numpy.ndarray
            Approximated wavefront signal.
        """

        # Image data after the check of condition
        img1, img2 = self._checkImageDim(I1, I2)

        # Wavefront signal S=-(1/I0)*(dI/dz) is approximated to be
        # -(1/delta z)*(I1-I2)/(I1+I2)
        imgDiff = img1 - img2
        imgSum = img1 + img2

        # The threshold is the middle between the minimum and maximum of the
        # non-zero summed pixels inside the non-padded mask
        # ( I0=(I1+I2)/2 )
        inMask = imgSum * self.cMask
        inMask = inMask[inMask != 0]

        lowest = inMask.min()
        highest = inMask.max()
        midLevel = (highest - lowest) / 2.0 + lowest

        # Summed pixels below the clipped threshold get a fixed value
        imgSum[imgSum < midLevel * cliplevel] = 1.5 * midLevel

        # Calculate delta z = f(f-l)/l, f: focal length, l: defocus distance of
        # the image planes
        focalLength = self._inst.getFocalLength()
        defocalOffset = self._inst.getDefocalDisOffset()
        deltaZ = focalLength * (focalLength - defocalOffset) / defocalOffset

        # A zero denominator is replaced by infinity to give a signal of 0
        imgSum[imgSum == 0] = np.inf

        signal = imgDiff / imgSum / deltaZ

        # Extend the dimension of signal to the order of 2 for "fft" to use
        return padArray(signal, self.getFftDimension()) * self.cMaskPad

    def _getdIandI(self, I1, I2):
        """Return the mean image I0 and the difference image dI used by the
        series expansion method.

        The image data come from _checkImageDim(), which validates the shapes
        and returns the second image rotated. The images are assumed to be
        co-centered already.

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.

        Returns
        -------
        numpy.ndarray
            I0 = (I1 + I2) / 2.
        numpy.ndarray
            dI = I2 - I1.
        """

        firstImg, secondImg = self._checkImageDim(I1, I2)

        meanImg = (firstImg + secondImg) / 2
        diffImg = secondImg - firstImg

        return meanImg, diffImg

    def _checkImageDim(self, I1, I2):
        """Validate the shapes of two defocal images and return their data.

        Both images have to be square (n by n) and of identical size. The
        data of I2 are returned rotated by 180 degree.

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.

        Returns
        -------
        numpy.ndarray
            I1 defocal image.
        numpy.ndarray
            I2 defocal image, rotated by 180 degree.

        Raises
        ------
        Exception
            One of the images is not square.
        Exception
            The two images do not have the same size.
        """

        firstImg = I1.getImg()
        secondImg = I2.getImg()

        rows1, cols1 = firstImg.shape
        rows2, cols2 = secondImg.shape

        if not (rows1 == cols1 and rows2 == cols2):
            raise Exception("Image is not square.")

        if not (rows1 == rows2 and cols1 == cols2):
            raise Exception("Images do not have the same size.")

        # Two successive 90-degree rotations give the 180-degree rotation
        return firstImg, np.rot90(secondImg, k=2)

    def _makeMasterMask(self, I1, I2, poissonSolver=None):
        """Compute the masks common to both defocal images.

        The padded mask (pMask) and the non-padded mask (cMask) are stored as
        the element-wise products of the corresponding masks of I1 and I2,
        i.e. their overlap region. Per the original notes this avoids the
        anomalous signal caused by a difference in vignetting.

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        poissonSolver : str, optional
            Algorithm to solve the Poisson's equation. Only for "fft" are the
            masks additionally padded to the FFT dimension and stored in
            pMaskPad and cMaskPad. (the default is None.)
        """

        # Overlap of the masks of the two images
        self.pMask = I1.getPaddedMask() * I2.getPaddedMask()
        self.cMask = I1.getNonPaddedMask() * I2.getNonPaddedMask()

        # The padded versions are needed by the "fft" solver only
        if poissonSolver != "fft":
            return

        fftDim = self.getFftDimension()
        self.pMaskPad = padArray(self.pMask, fftDim)
        self.cMaskPad = padArray(self.cMask, fftDim)

    def _applyI1I2pMask(self, I1, I2):
        """Mask and normalize the defocal images when I1 and I2 have
        different field positions (i.e. belong to different sources).

        When the field positions are equal the images are returned untouched.
        Otherwise both images are updated in place: each is multiplied by
        the common padded mask (rotated by 180 degree for I2) and then
        divided by its own sum.

        (Open question from the original implementation: if I1 and I2 come
        from different sources, what should the correction of TIE be? In that
        case the fieldX and fieldY of I1 and I2 differ as well.)

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.

        Returns
        -------
        CompensableImage
            I1, corrected in place when needed.
        CompensableImage
            I2, corrected in place when needed.
        """

        fromDifferentSources = I1.getFieldXY() != I2.getFieldXY()

        if fromDifferentSources:

            # Keep the overlap region only. The mask of I2 is rotated by 180
            # degree because I2 has been rotated by 180 degree already.
            maskedPairs = ((I1, self.pMask), (I2, np.rot90(self.pMask, 2)))
            for img, mask in maskedPairs:
                img.updateImage(img.getImg() * mask)

            # Normalize each image by its total intensity
            for img in (I1, I2):
                img.updateImage(img.getImg() / np.sum(img.getImg()))

        # No vignetting correction is needed here: per the original notes the
        # masking has been done in _singleItr() or itr0() already.
        return I1, I2

    def _reset(self, I1, I2):
        """Reset the outer-loop iteration counter and restore the defocal
        images to their initial data.

        Each image is updated with a copy of its initial image. An
        AttributeError raised while doing so is swallowed; per the original
        notes this happens on the very first run, when no initial image
        exists yet.

        Parameters
        ----------
        I1 : CompensableImage
            Intra- or extra-focal image.
        I2 : CompensableImage
            Intra- or extra-focal image.
        """

        self.currentItr = 0

        # Messages are shown at debug level 3 or higher only
        verbose = self.debugLevel >= 3
        if verbose:
            print("Resetting images: I1 and I2")

        try:
            # Restore I1 first, then I2
            for img in (I1, I2):
                img.updateImage(img.getImgInit().copy())
        except AttributeError:
            # No initial image to restore from
            if verbose:
                print("Image0 = None. This is the first time to run the code.")
        else:
            if verbose:
                print("Resetting images in inside.")

    def outZer4Up(self, unit="nm", filename=None, showPlot=False):
        """Write the coefficients of normal/ annular Zernike polynomials
        (z4 and up) to the terminal or to a file and plot them if needed.

        One line "Z<index> (Z<n>,<m>)<TAB> <value>" is written per
        coefficient. All lines are formatted before the output file is
        opened, and the file is handled by a context manager, so a
        formatting failure neither leaks a file handle nor leaves a
        partially written file behind.

        Parameters
        ----------
        unit : str, optional
            Unit of the coefficients of normal/ annular Zernike polynomials. It
            can be m, nm, or um. For any other value a message is printed and
            nothing is written. (the default is "nm".)
        filename : str, optional
            Name of output file. If None, the lines go to sys.stdout. (the
            default is None.)
        showPlot : bool, optional
            Decide to show the plot or not. (the default is False.)

        Raises
        ------
        IndexError
            More than 19 coefficients are stored (the label table below
            stops at Z22).
        """

        # Labels Zn,m of Z1 - Z22; the label of Zk is Znm[k - 1]
        Znm = (
            "Z0,0",
            "Z1,1",
            "Z1,-1",
            "Z2,0",
            "Z2,-2",
            "Z2,2",
            "Z3,-1",
            "Z3,1",
            "Z3,-3",
            "Z3,3",
            "Z4,0",
            "Z4,2",
            "Z4,-2",
            "Z4,4",
            "Z4,-4",
            "Z5,1",
            "Z5,-1",
            "Z5,3",
            "Z5,-3",
            "Z5,5",
            "Z5,-5",
            "Z6,0",
        )

        # The coefficients are stored in nm; convert to the requested unit
        if unit == "m":
            z = self.zer4UpNm * 1e-9
        elif unit == "nm":
            z = self.zer4UpNm
        elif unit == "um":
            z = self.zer4UpNm * 1e-3
        else:
            print("Unknown unit: %s" % unit)
            print("Unit options are: m, nm, um")
            return

        # The first stored coefficient is Z4
        zkIdx = range(4, len(z) + 4)

        # Format everything up front so that an error cannot interrupt a
        # half-written file
        lines = [
            "Z%d (%s)\t %8.3f\n" % (ii, Znm[ii - 1], z[ii - 4]) for ii in zkIdx
        ]

        if filename is None:
            for line in lines:
                sys.stdout.write(line)
        else:
            with open(filename, "w") as f:
                f.writelines(lines)

        # Show the plot
        if showPlot:
            plotZernike(zkIdx, z, unit)
Пример #27
0
def runWEP(instDir, algoFolderPath, useAlgorithm, imageFolderPath,
           intra_image_name, extra_image_name, fieldXY, opticalModel,
           showFig=False):
    """Calculate the coefficients of normal/ annular Zernike polynomials based
    on the provided instrument, algorithm, and optical model.

    The instrument is configured as CamType.LsstCam with an announced defocal
    distance of 1.0 mm and the image size of the intra-focal image. The
    algorithm is run with tol=1e-3 and the coefficients are printed through
    Algorithm.outZer4Up().

    Parameters
    ----------
    instDir : str
        Path to instrument folder.
    algoFolderPath : str
        Path to algorithm folder.
    useAlgorithm : str
        Algorithm to solve the Poisson's equation in the transport of intensity
        equation (TIE). It can be "fft" or "exp" here.
    imageFolderPath : str
        Path to image folder.
    intra_image_name : str
        File name of intra-focal image.
    extra_image_name : str
        File name of extra-focal image.
    fieldXY : tuple
        Position of donut on the focal plane in degree for intra- and
        extra-focal images.
    opticalModel : str
        Optical model. It can be "paraxial", "onAxis", or "offAxis".
    showFig : bool, optional
        Show the wavefront image and compensated image or not. (the default is
        False.)

    Returns
    -------
    numpy.ndarray
        Coefficients of Zernike polynomials (z4 - z22).
    """

    # Image files Path
    intra_image_file = os.path.join(imageFolderPath, intra_image_name)
    extra_image_file = os.path.join(imageFolderPath, extra_image_name)

    # There is the difference between intra and extra images
    # I1: intra_focal images, I2: extra_focal Images
    I1 = CompensableImage()
    I2 = CompensableImage()

    I1.setImg(fieldXY, DefocalType.Intra, imageFile=intra_image_file)
    I2.setImg(fieldXY, DefocalType.Extra, imageFile=extra_image_file)

    # Set the instrument
    inst = Instrument(instDir)
    inst.config(CamType.LsstCam, I1.getImgSizeInPix(),
                announcedDefocalDisInMm=1.0)

    # Define the algorithm to be used.
    algo = Algorithm(algoFolderPath)
    algo.config(useAlgorithm, inst, debugLevel=0)

    # Plot the original wavefront images. The image data are read through
    # getImg(), the accessor used for CompensableImage everywhere else.
    if showFig:
        plotImage(I1.getImg(), title="intra image")
        plotImage(I2.getImg(), title="extra image")

    # Run it
    algo.runIt(I1, I2, opticalModel, tol=1e-3)

    # Show the Zernikes Zn (n>=4)
    algo.outZer4Up(showPlot=False)

    # Plot the final compensated images and wavefront
    if showFig:
        plotImage(I1.getImg(), title="Compensated intra image")
        plotImage(I2.getImg(), title="Compensated extra image")

        # Plot the Wavefront
        plotImage(algo.wcomp, title="Final wavefront")
        plotImage(algo.wcomp, title="Final wavefront with pupil mask applied",
                  mask=algo.pMask)

    # Return the Zernikes Zn (n>=4)
    return algo.getZer4UpInNm()
Пример #28
0
    def makeTemplate(
        self,
        sensorName,
        defocalType,
        imageSize,
        camType=CamType.LsstCam,
        opticalModel="offAxis",
        pixelScale=0.2,
    ):
        """Make the donut template image.

        The template is the non-padded mask of an all-zero image placed at
        the center of the requested sensor.

        Parameters
        ----------
        sensorName : str
            The camera detector for which we want to make a template. Should
            be in "Rxx_Sxx" format. It is not used for CamType.AuxTel, where
            the first detector name of the Latiss camera is taken instead.
        defocalType : enum 'DefocalType'
            The defocal state of the sensor.
        imageSize : int
            Size of template in pixels. The template will be a square.
        camType : enum 'CamType', optional
            Camera type. (The default is CamType.LsstCam)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis".
            (The default is "offAxis")
        pixelScale : float, optional
            The pixels to arcseconds conversion factor. (The default is 0.2)

        Returns
        -------
        numpy.ndarray [int]
            The donut template as a binary image.

        Raises
        ------
        ValueError
            Optical model is not "onAxis" while the camera is AuxTel.
        ValueError
            Camera type is not supported.
        """

        configDir = getConfigDir()

        # Load Instrument parameters
        inst = Instrument(os.path.join(configDir, "cwfs", "instData"))

        if camType in (CamType.LsstCam, CamType.LsstFamCam, CamType.ComCam):
            inst.config(camType, imageSize)

            # Entry of the sensor in the focal plane layout: the first two
            # values are used as the sensor position in microns and the
            # third as the pixel size in microns
            sensorRecord = readPhoSimSettingData(
                configDir, "focalplanelayout.txt", "fieldCenter"
            )[sensorName]
            pixelSizeInUm = float(sensorRecord[2])
            sensorXMicron, sensorYMicron = np.array(sensorRecord[:2],
                                                    dtype=float)

        elif camType == CamType.AuxTel:
            # AuxTel only works with onAxis sources
            if opticalModel != "onAxis":
                raise ValueError(
                    f"Optical Model {opticalModel} not supported with AuxTel. "
                    "Must use 'onAxis'."
                )

            # The defocal distance for Latiss (in mm) has to be set
            # explicitly; the LSST cameras above rely on the default.
            inst.config(camType, imageSize, getDefocalDisInMm("auxTel"))

            # Pixel size is provided in meters
            pixelSizeInUm = inst.getCamPixelSize() * 1e6

            # Latiss has a single detector
            camera = obs_lsst.Latiss.getCamera()
            latissDetector = camera.get(list(camera.getNameIter())[0])

            # Center of CCD in mm
            xp, yp = latissDetector.getCenter(cameraGeom.FOCAL_PLANE)

            # mm --> microns. The x position is taken from yp and the y
            # position from xp, exactly as in the original code.
            # NOTE(review): presumably a coordinate-system transposition;
            # confirm.
            sensorXMicron = yp * 1000
            sensorYMicron = xp * 1000

        else:
            raise ValueError("Camera type (%s) is not supported." % camType)

        # microns -> pixels -> arcsec (times pixelScale) -> degrees (/ 3600)
        fieldXY = [
            float(micron) / pixelSizeInUm * pixelScale / 3600
            for micron in (sensorXMicron, sensorYMicron)
        ]

        # Build the mask of an empty image located at the sensor center
        noBoundary = 0
        unitMaskScaling = 1
        template = CompensableImage()
        template.setImg(fieldXY,
                        defocalType,
                        image=np.zeros((imageSize, imageSize)))
        template.makeMask(inst, opticalModel, noBoundary, unitMaskScaling)

        return template.getNonPaddedMask()
Пример #29
0
class WfEstimator(object):
    """Estimate the wavefront error from an intra-/ extra-focal donut image
    pair by configuring and running the Algorithm on an Instrument."""

    def __init__(self, instDir, algoDir):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instDir : str
            Path to instrument directory.
        algoDir : str
            Path to algorithm directory.
        """

        self.inst = Instrument(instDir)
        self.algo = Algorithm(algoDir)

        self.imgIntra = CompensableImage()
        self.imgExtra = CompensableImage()

        # Both values are set by config()
        self.opticalModel = ""
        self.sizeInPix = 0

    def getAlgo(self):
        """Get the algorithm object.

        Returns
        -------
        Algorithm
            Algorithm object.
        """

        return self.algo

    def getInst(self):
        """Get the instrument object.

        Returns
        -------
        Instrument
            Instrument object.
        """

        return self.inst

    def getIntraImg(self):
        """Get the intra-focal donut image.

        Returns
        -------
        CompensableImage
            Intra-focal donut image.
        """

        return self.imgIntra

    def getExtraImg(self):
        """Get the extra-focal donut image.

        Returns
        -------
        CompensableImage
            Extra-focal donut image.
        """

        return self.imgExtra

    def getOptModel(self):
        """Get the optical model.

        Returns
        -------
        str
            Optical model.
        """

        return self.opticalModel

    def getSizeInPix(self):
        """Get the donut image size in pixel defined by the config() function.

        Returns
        -------
        int
            Donut image size in pixel
        """

        return self.sizeInPix

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings.
        """

        self.algo.reset()

    def config(self, solver="exp", camType=CamType.LsstCam,
               opticalModel="offAxis", defocalDisInMm=None, sizeInPix=120,
               debugLevel=0):
        """Configure the TIE solver.

        Parameters
        ----------
        solver : str, optional
            Algorithm to solve the Poisson's equation in the transport of
            intensity equation (TIE). It can be "fft" or "exp" here. (the
            default is "exp".)
        camType : enum 'CamType'
            Camera type. (the default is CamType.LsstCam.)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis". (the
            default is "offAxis".)
        defocalDisInMm : float, optional
            Defocal distance in mm. If None, 1.5 mm is used. (the default is
            None.)
        sizeInPix : int, optional
            Wavefront image pixel size. (the default is 120.)
        debugLevel : int, optional
            Show the information under the running. If the value is higher,
            the information shows more. It can be 0, 1, 2, or 3. (the default
            is 0.)

        Raises
        ------
        ValueError
            Wrong Poisson solver name.
        ValueError
            Wrong optical model.
        """

        if solver not in ("exp", "fft"):
            raise ValueError("Poisson solver can not be '%s'." % solver)

        if opticalModel not in ("paraxial", "onAxis", "offAxis"):
            raise ValueError("Optical model can not be '%s'." % opticalModel)
        else:
            self.opticalModel = opticalModel

        # Fall back to the default defocal distance when none is given
        if defocalDisInMm is None:
            defocalDisInMm = 1.5

        self.sizeInPix = int(sizeInPix)
        self.inst.config(camType, self.sizeInPix,
                         announcedDefocalDisInMm=defocalDisInMm)

        self.algo.config(solver, self.inst, debugLevel=debugLevel)

    def setImg(self, fieldXY, defocalType, image=None, imageFile=None):
        """Set the wavefront image.

        Parameters
        ----------
        fieldXY : tuple or list
            Position of donut on the focal plane in degree for intra- and
            extra-focal images.
        defocalType : enum 'DefocalType'
            Defocal type of image.
        image : numpy.ndarray, optional
            Array of image. (the default is None.)
        imageFile : str, optional
            Path of image file. (the default is None.)

        Raises
        ------
        ValueError
            Defocal type is neither DefocalType.Intra nor DefocalType.Extra.
        """

        if defocalType == DefocalType.Intra:
            img = self.imgIntra
        elif defocalType == DefocalType.Extra:
            img = self.imgExtra
        else:
            # Without this branch "img" would be unbound below and the call
            # would fail with an obscure UnboundLocalError.
            raise ValueError("Defocal type can not be '%s'." % defocalType)

        img.setImg(fieldXY, defocalType, image=image, imageFile=imageFile)

    def calWfsErr(self, tol=1e-3, showZer=False, showPlot=False):
        """Calculate the wavefront error.

        Parameters
        ----------
        tol : float, optional
            Tolerance forwarded to Algorithm.runIt() (presumably the
            convergence criterion of the iteration). (the default is 1e-3.)
        showZer : bool, optional
            Decide to show the annular Zernike polynomials or not. (the
            default is False.)
        showPlot : bool, optional
            Decide to show the plot or not. Only used when showZer is True.
            (the default is False.)

        Returns
        -------
        numpy.ndarray
            Coefficients of Zernike polynomials (z4 - z22).

        Raises
        ------
        RuntimeError
            Input image shape is wrong.
        """

        # Both images have to match the size given to config()
        for img in (self.imgIntra, self.imgExtra):
            d1, d2 = img.getImg().shape
            if (d1 != self.sizeInPix) or (d2 != self.sizeInPix):
                raise RuntimeError("Input image shape is (%d, %d), not required (%d, %d)" % (
                    d1, d2, self.sizeInPix, self.sizeInPix))

        # Calculate the wavefront error.
        # Run cwfs
        self.algo.runIt(self.imgIntra, self.imgExtra, self.opticalModel, tol=tol)

        # Show the Zernikes Zn (n>=4)
        if showZer:
            self.algo.outZer4Up(showPlot=showPlot)

        return self.algo.getZer4UpInNm()
class CompensationImageDecoratorTest(unittest.TestCase):
    """Test the CompensationImageDecorator class."""

    def setUp(self):

        moduleRoot = getModulePath()

        # Instrument "lsst" sampled with 120 sensor samples
        self.inst = Instrument(
            os.path.join(moduleRoot, "configData", "cwfs", "instruData"))
        self.inst.config("lsst", 120)

        # Test images. The final image format is not known yet; the image
        # reader is based on txt files.
        testImageDir = os.path.join(moduleRoot, "tests", "testData",
                                    "testImages")
        self.imgFilePathIntra = os.path.join(testImageDir, "LSST_NE_SN25",
                                             "z11_0.25_intra.txt")
        self.imgFilePathExtra = os.path.join(testImageDir, "LSST_NE_SN25",
                                             "z11_0.25_extra.txt")

        # Position of donut on the focal plane in degree:
        # [1.185, 1.185] or [0, 0]
        self.fieldXY = [1.185, 1.185]

        # Optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # True Zk
        self.zcCol = np.loadtxt(
            os.path.join(testImageDir, "validation",
                         "LSST_NE_SN25_z11_0.25_exp.txt"))

    def testFunc(self):

        donut = CompensationImageDecorator()

        # Set the image
        donut.setImg(self.fieldXY, imageFile=self.imgFilePathIntra,
                     atype="intra")
        self.assertEqual(donut.image.shape, (120, 120))
        self.assertEqual(donut.fieldX, self.fieldXY[0])
        self.assertEqual(donut.fieldY, self.fieldXY[1])

        # Update the image0
        donut.updateImage0()
        self.assertEqual(np.sum(np.abs(donut.image0-donut.image)), 0)

        # Off axis correction
        corrOrder = 10
        donut.getOffAxisCorr(
            os.path.join(self.inst.instDir, self.inst.instName), corrOrder)
        self.assertEqual(donut.offAxisOffset, 0.001)
        self.assertEqual(donut.offAxis_coeff.shape, (4, 66))
        self.assertAlmostEqual(donut.offAxis_coeff[0, 0], -2.6362089*1e-3)

        # Mask list of the paraxial model
        paraxialMasks = donut.makeMaskList(self.inst, "paraxial")
        paraxialAns = np.array([[0, 0, 1, 1], [0, 0, 0.61, 0]])
        self.assertEqual(np.sum(np.abs(paraxialMasks-paraxialAns)), 0)

        # Mask list of the off-axis model
        offAxisMasks = donut.makeMaskList(self.inst, "offAxis")
        offAxisAns = np.array([[0, 0, 1, 1], [0, 0, 0.61, 0],
                               [-0.21240585, -0.21240585, 1.2300922, 1],
                               [-0.08784336, -0.08784336, 0.55802573, 0]])
        self.assertAlmostEqual(np.sum(np.abs(offAxisMasks-offAxisAns)), 0)

        # Make the mask. The boundary thickness is in pixel.
        boundaryThickness = 8
        localMaskScaling = 1
        donut.makeMask(self.inst, "offAxis", boundaryThickness,
                       localMaskScaling)
        self.assertEqual(donut.pMask.shape, donut.image.shape)
        self.assertEqual(donut.cMask.shape, donut.image.shape)
        self.assertEqual(np.sum(np.abs(donut.cMask - donut.pMask)), 3001)

        # Image co-center
        donut.imageCoCenter(self.inst)
        centerX, centerY = donut.getCenterAndR_ef()[0:2]
        self.assertEqual(int(centerX), 63)
        self.assertEqual(int(centerY), 63)

    def testFuncCompensation(self):

        # Fake algorithm object carrying the parameters only
        fakeAlgo = tempAlgo()
        fakeAlgo.parameter = {"numTerms": 22, "offAxisPolyOrder": 10, "zobsR": 0.61}

        boundaryThickness = 8
        corrOrder = 10

        # Coefficient vector of 22 terms: the first three are 0, the rest
        # are the true values scaled by 1e-9
        zkInput = np.zeros(22)
        zkInput[3:] = self.zcCol*1e-9
        instDataDir = os.path.join(self.inst.instDir, self.inst.instName)

        intraDonut = CompensationImageDecorator()
        extraDonut = CompensationImageDecorator()
        intraDonut.setImg(self.fieldXY, imageFile=self.imgFilePathIntra,
                          atype="intra")
        extraDonut.setImg(self.fieldXY, imageFile=self.imgFilePathExtra,
                          atype="extra")

        # Compensate both images
        for donut in (intraDonut, extraDonut):
            donut.makeMask(self.inst, self.opticalModel, boundaryThickness, 1)
            donut.getOffAxisCorr(instDataDir, corrOrder)
            donut.imageCoCenter(self.inst)
            donut.compensate(self.inst, fakeAlgo, zkInput, self.opticalModel)

        # Common region: pixels set in both binary images
        overlap = (intraDonut.getCenterAndR_ef()[3]
                   + extraDonut.getCenterAndR_ef()[3])
        overlap[overlap < 2] = 0
        overlap = overlap / 2

        # Difference of the compensated images inside the common region
        residual = np.sum(
            np.abs(intraDonut.getImg() - extraDonut.getImg()) * overlap)
        self.assertLess(residual, 500)
Пример #31
0
class TestAlgorithm(unittest.TestCase):
    """Test the Algorithm class."""

    def setUp(self):

        # Get the path of module
        self.modulePath = getModulePath()

        # Define the instrument folder
        instruFolder = os.path.join(self.modulePath, "configData", "cwfs",
                                    "instruData")

        # Define the algorithm folder. It is shared by the test cases, which
        # choose the solver ("exp" or "fft") themselves.
        self.algoFolderPath = os.path.join(self.modulePath, "configData",
                                           "cwfs", "algo")

        # Define the instrument name
        instruName = "lsst"

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # It is noted that the image reader is based on the txt file
        imageFolderPath = os.path.join(self.modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"

        # Define fieldXY: [1.185, 1.185] or [0, 0]
        # This is the position of donut on the focal plane in degree
        fieldXY = [1.185, 1.185]

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Image files Path
        intra_image_file = os.path.join(imageFolderPath, intra_image_name)
        extra_image_file = os.path.join(imageFolderPath, extra_image_name)

        # There is the difference between intra and extra images
        # I1: intra_focal images, I2: extra_focal Images
        self.I1 = CompensationImageDecorator()
        self.I2 = CompensationImageDecorator()

        self.I1.setImg(fieldXY, imageFile=intra_image_file, atype="intra")
        self.I2.setImg(fieldXY, imageFile=extra_image_file, atype="extra")

        self.inst = Instrument(instruFolder)
        self.inst.config(instruName, self.I1.sizeinPix)

    def testExp(self):

        # Define the algorithm to be used: "exp"
        algo = Algorithm(self.algoFolderPath)
        algo.config("exp", self.inst, debugLevel=3)
        algo.setDebugLevel(0)
        self.assertEqual(algo.debugLevel, 0)

        # Test functions: itr0() and nextItr(). Running itr0() again after
        # further iterations has to reproduce the first result.
        algo.itr0(self.inst, self.I1, self.I2, self.opticalModel)
        tmp1 = algo.zer4UpNm
        algo.nextItr(self.inst, self.I1, self.I2, self.opticalModel, nItr=2)
        algo.itr0(self.inst, self.I1, self.I2, self.opticalModel)
        tmp2 = algo.zer4UpNm

        difference = np.sum(np.abs(tmp1 - tmp2))
        self.assertEqual(difference, 0)

        # Run it
        algo.runIt(self.inst, self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Check the value
        Zk = algo.zer4UpNm
        self.assertEqual(int(Zk[7]), -192)

        # Reset and check the calculation again
        fieldXY = [self.I1.fieldX, self.I1.fieldY]
        self.I1.setImg(fieldXY, image=self.I1.image0, atype=self.I1.atype)
        self.I2.setImg(fieldXY, image=self.I2.image0, atype=self.I2.atype)
        algo.reset()
        algo.runIt(self.inst, self.I1, self.I2, self.opticalModel, tol=1e-3)
        Zk = algo.zer4UpNm
        self.assertEqual(int(Zk[7]), -192)

    def testFFT(self):

        # Define the algorithm to be used: "fft"
        algo = Algorithm(self.algoFolderPath)
        algo.config("fft", self.inst, debugLevel=0)

        # Run it
        algo.runIt(self.inst, self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Check the value
        Zk = algo.zer4UpNm
        self.assertEqual(int(Zk[7]), -192)
Пример #32
0
class TestCompensableImage(unittest.TestCase):
    """Test the CompensableImage class."""

    def setUp(self):

        # Get the path of module
        modulePath = getModulePath()

        # Define the instrument folder
        instDir = os.path.join(getConfigDir(), "cwfs", "instData")

        # Dimension of the donut on the sensor used to configure the
        # instrument
        dimOfDonutOnSensor = 120

        self.inst = Instrument(instDir)
        self.inst.config(CamType.LsstCam,
                         dimOfDonutOnSensor,
                         announcedDefocalDisInMm=1.0)

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # It is noted that the image.readFile inputs are based on the txt file
        imageFolderPath = os.path.join(modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"
        self.imgFilePathIntra = os.path.join(imageFolderPath, intra_image_name)
        self.imgFilePathExtra = os.path.join(imageFolderPath, extra_image_name)

        # This is the position of donut on the focal plane in degree
        self.fieldXY = (1.185, 1.185)

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Get the true Zk
        zcAnsFilePath = os.path.join(
            modulePath,
            "tests",
            "testData",
            "testImages",
            "validation",
            "simulation",
            "LSST_NE_SN25_z11_0.25_exp.txt",
        )
        self.zcCol = np.loadtxt(zcAnsFilePath)

        # Fresh image object (no image set yet) used by most tests
        self.wfsImg = CompensableImage()

    def testGetDefocalType(self):

        defocalType = self.wfsImg.getDefocalType()
        self.assertEqual(defocalType, DefocalType.Intra)

    def testGetImgObj(self):

        imgObj = self.wfsImg.getImgObj()
        self.assertIsInstance(imgObj, Image)

    def testGetImg(self):

        img = self.wfsImg.getImg()
        self.assertIsInstance(img, np.ndarray)
        self.assertEqual(len(img), 0)

    def testGetImgSizeInPix(self):

        imgSizeInPix = self.wfsImg.getImgSizeInPix()
        self.assertEqual(imgSizeInPix, 0)

    def testGetOffAxisCoeff(self):

        offAxisCoeff, offAxisOffset = self.wfsImg.getOffAxisCoeff()
        self.assertIsInstance(offAxisCoeff, np.ndarray)
        self.assertEqual(len(offAxisCoeff), 0)
        self.assertEqual(offAxisOffset, 0.0)

    def testGetImgInit(self):

        imgInit = self.wfsImg.getImgInit()
        self.assertIsNone(imgInit)

    def testIsCaustic(self):

        self.assertFalse(self.wfsImg.isCaustic())

    def testGetPaddedMask(self):

        pMask = self.wfsImg.getPaddedMask()
        self.assertEqual(len(pMask), 0)
        self.assertEqual(pMask.dtype, int)

    def testGetNonPaddedMask(self):

        cMask = self.wfsImg.getNonPaddedMask()
        self.assertEqual(len(cMask), 0)
        self.assertEqual(cMask.dtype, int)

    def testGetFieldXY(self):

        fieldX, fieldY = self.wfsImg.getFieldXY()
        self.assertEqual(fieldX, 0)
        self.assertEqual(fieldY, 0)

    def testSetImg(self):

        self._setIntraImg()
        self.assertEqual(self.wfsImg.getImg().shape, (120, 120))

    def _setIntraImg(self):
        """Load the intra-focal test image file into self.wfsImg."""

        self.wfsImg.setImg(self.fieldXY,
                           DefocalType.Intra,
                           imageFile=self.imgFilePathIntra)

    def testUpdateImage(self):

        self._setIntraImg()

        # The values are random, but the check only compares the stored image
        # with the very same array, so no seed is needed
        newImg = np.random.rand(5, 5)
        self.wfsImg.updateImage(newImg)

        self.assertTrue(np.all(self.wfsImg.getImg() == newImg))

    def testUpdateImgInit(self):

        self._setIntraImg()

        self.wfsImg.updateImgInit()

        delta = np.sum(np.abs(self.wfsImg.getImgInit() - self.wfsImg.getImg()))
        self.assertEqual(delta, 0)

    def testImageCoCenter(self):

        self._setIntraImg()

        self.wfsImg.imageCoCenter(self.inst)

        xc, yc = self.wfsImg.getImgObj().getCenterAndR()[0:2]
        self.assertEqual(int(xc), 63)
        self.assertEqual(int(yc), 63)

    def testCompensate(self):

        # Generate a fake algorithm class
        algo = TempAlgo()

        # Test the function of image compensation
        boundaryT = 8
        offAxisCorrOrder = 10

        # The answer file holds the coefficients from the 4th term on; they
        # are scaled by 1e-9 before being handed to compensate()
        zcCol = np.zeros(22)
        zcCol[3:] = self.zcCol * 1e-9

        wfsImgIntra = CompensableImage()
        wfsImgExtra = CompensableImage()
        wfsImgIntra.setImg(
            self.fieldXY,
            DefocalType.Intra,
            imageFile=self.imgFilePathIntra,
        )
        wfsImgExtra.setImg(self.fieldXY,
                           DefocalType.Extra,
                           imageFile=self.imgFilePathExtra)

        for wfsImg in [wfsImgIntra, wfsImgExtra]:
            wfsImg.makeMask(self.inst, self.opticalModel, boundaryT, 1)
            wfsImg.setOffAxisCorr(self.inst, offAxisCorrOrder)
            wfsImg.imageCoCenter(self.inst)
            wfsImg.compensate(self.inst, algo, zcCol, self.opticalModel)

        # Get the common region
        intraImg = wfsImgIntra.getImg()
        extraImg = wfsImgExtra.getImg()

        centroid = CentroidRandomWalk()
        binaryImgIntra = centroid.getImgBinary(intraImg)
        binaryImgExtra = centroid.getImgBinary(extraImg)

        # Keep only the pixels set in both binary images (sum equals 2) and
        # scale them back to 1
        binaryImg = binaryImgIntra + binaryImgExtra
        binaryImg[binaryImg < 2] = 0
        binaryImg = binaryImg / 2

        # Calculate the difference
        res = np.sum(np.abs(intraImg - extraImg) * binaryImg)
        self.assertLess(res, 500)

    def testCenterOnProjection(self):

        template = self._prepareGaussian2D(100, 1)

        # Shift the template to get a de-centered image
        dx = 2
        dy = 8
        img = np.roll(np.roll(template, dx, axis=1), dy, axis=0)

        # The shifted image has to differ from the template before it is
        # recentered
        self.assertGreater(np.sum(np.abs(img - template)), 29)

        imgRecenter = self.wfsImg.centerOnProjection(img, template, window=20)
        self.assertLess(np.sum(np.abs(imgRecenter - template)), 1e-7)

    def _prepareGaussian2D(self, imgSize, sigma):
        """Get a 2D Gaussian sampled on a square grid over [-10, 10].

        Parameters
        ----------
        imgSize : int
            Number of samples along each axis.
        sigma : float
            Standard deviation of the Gaussian, in grid units.

        Returns
        -------
        numpy.ndarray
            Array of shape (imgSize, imgSize) with the Gaussian values.
        """

        x = np.linspace(-10, 10, imgSize)
        y = np.linspace(-10, 10, imgSize)

        xx, yy = np.meshgrid(x, y)

        return (1 / (2 * np.pi * sigma**2) *
                np.exp(-(xx**2 / (2 * sigma**2) + yy**2 / (2 * sigma**2))))

    def testSetOffAxisCorr(self):

        self._setIntraImg()

        offAxisCorrOrder = 10
        self.wfsImg.setOffAxisCorr(self.inst, offAxisCorrOrder)

        offAxisCoeff, offAxisOffset = self.wfsImg.getOffAxisCoeff()
        self.assertEqual(offAxisCoeff.shape, (4, 66))
        self.assertAlmostEqual(offAxisCoeff[0, 0], -2.6362089 * 1e-3)
        self.assertEqual(offAxisOffset, 0.001)

    def testMakeMaskListOfParaxial(self):

        self._setIntraImg()

        model = "paraxial"
        masklist = self.wfsImg.makeMaskList(self.inst, model)

        masklistAns = np.array([[0, 0, 1, 1], [0, 0, 0.61, 0]])
        self.assertEqual(np.sum(np.abs(masklist - masklistAns)), 0)

    def testMakeMaskListOfOffAxis(self):

        self._setIntraImg()

        model = "offAxis"
        masklist = self.wfsImg.makeMaskList(self.inst, model)

        masklistAns = np.array([
            [0, 0, 1, 1],
            [0, 0, 0.61, 0],
            [-0.21240585, -0.21240585, 1.2300922, 1],
            [-0.08784336, -0.08784336, 0.55802573, 0],
        ])
        self.assertAlmostEqual(np.sum(np.abs(masklist - masklistAns)), 0)

    def testMakeMask(self):

        self._setIntraImg()

        boundaryT = 8
        maskScalingFactorLocal = 1
        model = "offAxis"
        self.wfsImg.makeMask(self.inst, model, boundaryT,
                             maskScalingFactorLocal)

        image = self.wfsImg.getImg()
        pMask = self.wfsImg.getPaddedMask()
        cMask = self.wfsImg.getNonPaddedMask()
        self.assertEqual(pMask.shape, image.shape)
        self.assertEqual(cMask.shape, image.shape)
        self.assertEqual(np.sum(np.abs(cMask - pMask)), 3001)
Пример #33
0
class TestAlgorithm(unittest.TestCase):
    """Test the Algorithm class."""

    def setUp(self):

        # Get the path of module
        self.modulePath = getModulePath()

        # Define the image folder and image names
        # Image data -- Don't know the final image format.
        # It is noted that the image.readFile inputs are based on the txt file
        imageFolderPath = os.path.join(self.modulePath, "tests", "testData",
                                       "testImages", "LSST_NE_SN25")
        intra_image_name = "z11_0.25_intra.txt"
        extra_image_name = "z11_0.25_extra.txt"

        # Define fieldXY: [1.185, 1.185] or [0, 0]
        # This is the position of donut on the focal plane in degree
        fieldXY = [1.185, 1.185]

        # Define the optical model: "paraxial", "onAxis", "offAxis"
        self.opticalModel = "offAxis"

        # Image files Path
        intra_image_file = os.path.join(imageFolderPath, intra_image_name)
        extra_image_file = os.path.join(imageFolderPath, extra_image_name)

        # There is a difference between the intra and extra images
        # I1: intra_focal images, I2: extra_focal Images
        self.I1 = CompensableImage()
        self.I2 = CompensableImage()

        self.I1.setImg(fieldXY, DefocalType.Intra, imageFile=intra_image_file)
        self.I2.setImg(fieldXY, DefocalType.Extra, imageFile=extra_image_file)

        # Set up the instrument
        cwfsConfigDir = os.path.join(getConfigDir(), "cwfs")

        instDir = os.path.join(cwfsConfigDir, "instData")
        self.inst = Instrument(instDir)

        self.inst.config(CamType.LsstCam, self.I1.getImgSizeInPix(),
                         announcedDefocalDisInMm=1.0)

        # Set up the algorithm: one instance per Poisson solver
        algoDir = os.path.join(cwfsConfigDir, "algo")

        self.algoExp = Algorithm(algoDir)
        self.algoExp.config("exp", self.inst)

        self.algoFft = Algorithm(algoDir)
        self.algoFft.config("fft", self.inst)

    def testGetDebugLevel(self):

        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testSetDebugLevel(self):

        self.algoExp.config("exp", self.inst, debugLevel=3)
        self.assertEqual(self.algoExp.getDebugLevel(), 3)

        self.algoExp.setDebugLevel(0)
        self.assertEqual(self.algoExp.getDebugLevel(), 0)

    def testGetZer4UpInNm(self):

        zer4UpNm = self.algoExp.getZer4UpInNm()
        self.assertIsInstance(zer4UpNm, np.ndarray)

    def testGetPoissonSolverName(self):

        self.assertEqual(self.algoExp.getPoissonSolverName(), "exp")
        self.assertEqual(self.algoFft.getPoissonSolverName(), "fft")

    def testGetNumOfZernikes(self):

        self.assertEqual(self.algoExp.getNumOfZernikes(), 22)
        self.assertEqual(self.algoFft.getNumOfZernikes(), 22)

    def testGetZernikeTerms(self):

        # assertTrue(dtype, int) would treat "int" as the failure message and
        # always pass, so the integer dtype is checked explicitly instead
        zTerms = self.algoExp.getZernikeTerms()
        self.assertTrue(np.issubdtype(zTerms.dtype, np.integer))
        self.assertEqual(len(zTerms), self.algoExp.getNumOfZernikes())
        self.assertEqual(zTerms[0], 1)
        self.assertEqual(zTerms[-1], self.algoExp.getNumOfZernikes())

        # The "fft" terms are compared with the "fft" algorithm's own count
        zTerms = self.algoFft.getZernikeTerms()
        self.assertTrue(np.issubdtype(zTerms.dtype, np.integer))
        self.assertEqual(len(zTerms), self.algoFft.getNumOfZernikes())

    def testGetObsOfZernikes(self):

        self.assertEqual(self.algoExp.getObsOfZernikes(),
                         self.inst.getObscuration())
        self.assertEqual(self.algoFft.getObsOfZernikes(),
                         self.inst.getObscuration())

    def testGetNumOfOuterItr(self):

        self.assertEqual(self.algoExp.getNumOfOuterItr(), 14)
        self.assertEqual(self.algoFft.getNumOfOuterItr(), 14)

    def testGetNumOfInnerItr(self):

        self.assertEqual(self.algoFft.getNumOfInnerItr(), 6)

    def testGetFeedbackGain(self):

        self.assertEqual(self.algoExp.getFeedbackGain(), 0.6)
        self.assertEqual(self.algoFft.getFeedbackGain(), 0.6)

    def testGetOffAxisPolyOrder(self):

        self.assertEqual(self.algoExp.getOffAxisPolyOrder(), 10)
        self.assertEqual(self.algoFft.getOffAxisPolyOrder(), 10)

    def testGetCompensatorMode(self):

        self.assertEqual(self.algoExp.getCompensatorMode(), "zer")
        self.assertEqual(self.algoFft.getCompensatorMode(), "zer")

    def testGetCompSequence(self):

        compSequence = self.algoExp.getCompSequence()
        self.assertIsInstance(compSequence, np.ndarray)
        self.assertEqual(compSequence.dtype, int)
        self.assertEqual(len(compSequence), self.algoExp.getNumOfOuterItr())
        self.assertEqual(compSequence[0], 4)
        self.assertEqual(compSequence[-1], 22)

        compSequence = self.algoFft.getCompSequence()
        self.assertEqual(len(compSequence), self.algoFft.getNumOfOuterItr())

    def testGetBoundaryThickness(self):

        self.assertEqual(self.algoExp.getBoundaryThickness(), 8)
        self.assertEqual(self.algoFft.getBoundaryThickness(), 1)

    def testGetFftDimension(self):

        self.assertEqual(self.algoFft.getFftDimension(), 128)

    def testGetSignalClipSequence(self):

        # The sequence comes from the "fft" algorithm, so its length is
        # compared with that algorithm's own outer-loop count plus one
        sumclipSequence = self.algoFft.getSignalClipSequence()
        self.assertIsInstance(sumclipSequence, np.ndarray)
        self.assertEqual(len(sumclipSequence),
                         self.algoFft.getNumOfOuterItr() + 1)
        self.assertEqual(sumclipSequence[0], 0.33)
        self.assertEqual(sumclipSequence[-1], 0.51)

    def testGetMaskScalingFactor(self):

        self.assertAlmostEqual(self.algoExp.getMaskScalingFactor(), 1.0939,
                               places=4)
        self.assertAlmostEqual(self.algoFft.getMaskScalingFactor(), 1.0939,
                               places=4)

    def testRunItOfExp(self):

        # Test functions: itr0() and nextItr()
        # Calling itr0() again after nextItr() has to reproduce the result of
        # the first itr0() call
        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp1 = self.algoExp.getZer4UpInNm()
        self.algoExp.nextItr(self.I1, self.I2, self.opticalModel, nItr=2)
        self.algoExp.itr0(self.I1, self.I2, self.opticalModel)
        tmp2 = self.algoExp.getZer4UpInNm()

        difference = np.sum(np.abs(tmp1 - tmp2))
        self.assertEqual(difference, 0)

        # Run it
        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        # Check the value
        Zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(Zk[7]), -192)

        # Reset and check the calculation again
        # The images are set back to their initial image before the rerun
        fieldXY = [self.I1.fieldX, self.I1.fieldY]
        self.I1.setImg(fieldXY, self.I1.getDefocalType(),
                       image=self.I1.getImgInit())
        self.I2.setImg(fieldXY, self.I2.getDefocalType(),
                       image=self.I2.getImgInit())
        self.algoExp.reset()
        self.algoExp.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)
        Zk = self.algoExp.getZer4UpInNm()
        self.assertEqual(int(Zk[7]), -192)

    def testRunItOfFft(self):

        self.algoFft.runIt(self.I1, self.I2, self.opticalModel, tol=1e-3)

        zk = self.algoFft.getZer4UpInNm()
        self.assertEqual(int(zk[7]), -192)
Пример #34
0
 def testInstrument(self):
     # Configure an instrument with a known sample count and read the value
     # back from its parameter dictionary
     sensorSamples = 120
     instrument = Instrument(self.instruFolder)
     instrument.config(self.instruName, sensorSamples)
     self.assertEqual(instrument.parameter["sensorSamples"], sensorSamples)
Пример #35
0
class WfEstimator(object):
    """Estimate the wavefront error from intra- and extra-focal images.

    The estimator holds an Algorithm, an Instrument, the two defocal images
    and the optical model, and hands them to the algorithm in calWfsErr().
    """

    def __init__(self, instruFolderPath, algoFolderPath):
        """Initialize the wavefront estimator class.

        Parameters
        ----------
        instruFolderPath : str
            Path to instrument directory.
        algoFolderPath : str
            Path to algorithm directory.
        """

        self.algo = Algorithm(algoFolderPath)
        self.inst = Instrument(instruFolderPath)
        self.ImgIntra = CompensationImageDecorator()
        self.ImgExtra = CompensationImageDecorator()
        self.opticalModel = ""

        self.sizeInPix = 0

    def getAlgo(self):
        """Get the algorithm object.

        Returns
        -------
        Algorithm
            Algorithm object.
        """

        return self.algo

    def getInst(self):
        """Get the instrument object.

        Returns
        -------
        Instrument
            Instrument object.
        """

        return self.inst

    def getIntraImg(self):
        """Get the intra-focal donut image.

        Returns
        -------
        CompensationImageDecorator
            Intra-focal donut image.
        """

        return self.ImgIntra

    def getExtraImg(self):
        """Get the extra-focal donut image.

        Returns
        -------
        CompensationImageDecorator
            Extra-focal donut image.
        """

        return self.ImgExtra

    def getOptModel(self):
        """Get the optical model.

        Returns
        -------
        str
            Optical model.
        """

        return self.opticalModel

    def getSizeInPix(self):
        """Get the donut image size in pixel defined by the config() function.

        Returns
        -------
        int
            Donut image size in pixel
        """

        return self.sizeInPix

    def reset(self):
        """Reset the calculation for the new input images with the same
        algorithm settings.
        """

        self.algo.reset()

    def config(self,
               solver="exp",
               instName="lsst",
               opticalModel="offAxis",
               defocalDisInMm=None,
               sizeInPix=120,
               debugLevel=0):
        """Configure the TIE solver.

        Parameters
        ----------
        solver : str, optional
            Algorithm to solve the Poisson's equation in the transport of
            intensity equation (TIE). It can be "fft" or "exp" here. (the
            default is "exp".)
        instName : str, optional
            Instrument name. It is "lsst" in the baseline. (the default is
            "lsst".)
        opticalModel : str, optional
            Optical model. It can be "paraxial", "onAxis", or "offAxis". (the
            default is "offAxis".)
        defocalDisInMm : float, optional
            Defocal distance in mm. If given, int(10 * defocalDisInMm) is
            appended to the instrument name. (the default is None.)
        sizeInPix : int, optional
            Wavefront image pixel size. (the default is 120.)
        debugLevel : int, optional
            Show the information under the running. If the value is higher,
            the information shows more. It can be 0, 1, 2, or 3. (the default
            is 0.)

        Raises
        ------
        ValueError
            Wrong instrument name.
        ValueError
            Wrong Poisson solver name.
        ValueError
            Wrong optical model.
        """

        # Check the inputs and assign the parameters used in the TIE
        # Need to change the way to hold the instance of Instrument and
        # Algorithm

        # Update the instrument name with the defocal distance suffix
        if defocalDisInMm is not None:
            instName = instName + str(int(10 * defocalDisInMm))

        if instName not in ("lsst", "lsst05", "lsst10", "lsst15", "lsst20",
                            "lsst25", "comcam10", "comcam15", "comcam20"):
            raise ValueError("Instrument can not be '%s'." % instName)

        # Set the available wavefront image size (n x n)
        self.sizeInPix = int(sizeInPix)

        # Configure the instrument
        self.inst.config(instName, self.sizeInPix)

        if solver not in ("exp", "fft"):
            raise ValueError("Poisson solver can not be '%s'." % solver)
        self.algo.config(solver, self.inst, debugLevel=debugLevel)

        if opticalModel not in ("paraxial", "onAxis", "offAxis"):
            raise ValueError("Optical model can not be '%s'." % opticalModel)
        self.opticalModel = opticalModel

    def setImg(self, fieldXY, image=None, imageFile=None, defocalType=None):
        """Set the wavefront image.

        Parameters
        ----------
        fieldXY : float
            Position of donut on the focal plane in degree for intra- and
            extra-focal images.
        image : float, optional
            Array of image. (the default is None.)
        imageFile : str, optional
            Path of image file. (the default is None.)
        defocalType : str, optional
            Type of image. It should be "intra" or "extra". (the default is
            None.)

        Raises
        ------
        ValueError
            Wrong defocal type.
        """

        # Check the defocal type
        if defocalType not in (self.ImgIntra.INTRA, self.ImgIntra.EXTRA):
            raise ValueError("Defocal type can not be '%s'." % defocalType)

        # The type was validated above, so anything that is not intra-focal
        # is extra-focal
        if defocalType == self.ImgIntra.INTRA:
            targetImg = self.ImgIntra
        else:
            targetImg = self.ImgExtra

        # Read the image and assign the type
        targetImg.setImg(fieldXY,
                         image=image,
                         imageFile=imageFile,
                         atype=defocalType)

    def calWfsErr(self, tol=1e-3, showZer=False, showPlot=False):
        """Calculate the wavefront error.

        Parameters
        ----------
        tol : float, optional
            Tolerance of difference of coefficients of Zk polynomials
            compared with the previous iteration. (the default is 1e-3.)
        showZer : bool, optional
            Decide to show the annular Zernike polynomials or not. (the
            default is False.)
        showPlot : bool, optional
            Decide to show the plot or not. (the default is False.)

        Returns
        -------
        float
            Coefficients of Zernike polynomials (z4 - z22).

        Raises
        ------
        RuntimeError
            Input image shape differs from the configured size.
        """

        # Check the image size
        for img in (self.ImgIntra, self.ImgExtra):
            d1, d2 = img.image.shape
            if (d1 != self.sizeInPix) or (d2 != self.sizeInPix):
                raise RuntimeError(
                    "Input image shape is (%d, %d), not required (%d, %d)" %
                    (d1, d2, self.sizeInPix, self.sizeInPix))

        # Calculate the wavefront error.
        # Run cwfs
        self.algo.runIt(self.inst,
                        self.ImgIntra,
                        self.ImgExtra,
                        self.opticalModel,
                        tol=tol)

        # Show the Zernikes Zn (n>=4)
        if showZer:
            self.algo.outZer4Up(showPlot=showPlot)

        return self.algo.zer4UpNm

    def outParam(self, filename=None):
        """Put the information of images, instrument, and algorithm on
        terminal or file.

        Parameters
        ----------
        filename : str, optional
            Name of output file. If None, the information is written to
            sys.stdout. (the default is None.)
        """

        if filename is None:
            self._writeParam(sys.stdout)
        else:
            # The context manager closes the file even if writing fails
            with open(filename, "w") as fout:
                self._writeParam(fout)

    def _writeParam(self, fout):
        """Write the image, optical model and configuration information.

        Parameters
        ----------
        fout : file
            Writable file instance.
        """

        # Write the information of image and optical model
        if self.ImgIntra.name is not None:
            fout.write("Intra image: \t %s\n" % self.ImgIntra.name)

        if self.ImgIntra.fieldX is not None:
            fout.write("Intra image field in deg =(%6.3f, %6.3f)\n" %
                       (self.ImgIntra.fieldX, self.ImgIntra.fieldY))

        if self.ImgExtra.name is not None:
            fout.write("Extra image: \t %s\n" % self.ImgExtra.name)

        if self.ImgExtra.fieldX is not None:
            fout.write("Extra image field in deg =(%6.3f, %6.3f)\n" %
                       (self.ImgExtra.fieldX, self.ImgExtra.fieldY))

        if self.opticalModel is not None:
            fout.write("Using optical model:\t %s\n" % self.opticalModel)

        # Read the instrument file
        if self.inst.filename is not None:
            self.__readConfigFile(fout, self.inst, "instrument")

        # Read the algorithm file
        if self.algo.filename is not None:
            self.__readConfigFile(fout, self.algo, "algorithm")

    def __readConfigFile(self, fout, config, configName):
        """Copy the effective lines of a configuration file to the output.

        Blank lines, lines starting with "#", and every line between a pair
        of lines starting with "###" are skipped.

        Parameters
        ----------
        fout : file
            Writable file instance.
        config : Instrument or Algorithm
            Instance of configuration. Its "filename" attribute is the file
            to read.
        configName : str
            Name of configuration.
        """

        # Create a new line
        fout.write("\n")

        # The context manager closes the file even if reading/writing fails
        with open(config.filename) as fconfig:
            fout.write("---" + configName +
                       " file: --- %s ----------\n" % config.filename)

            # A line starting with "###" opens or closes a comment block.
            # Use "not" to toggle: bitwise "~" on a bool yields -1/0 and is
            # deprecated.
            inCommentBlock = False
            for line in fconfig:
                line = line.strip()
                if line.startswith("###"):
                    inCommentBlock = not inCommentBlock
                if line and not inCommentBlock and not line.startswith("#"):
                    fout.write(line + "\n")