コード例 #1
0
ファイル: thresher.py プロジェクト: kcompher/TheThresher
    def __init__(self, initial, image_list, mask_list=None, invert=False,
            square=False, outdir="", centers=None, psf_hw=13, kernel=None,
            psfreg=0., sceneL2=0.0, dc=0.0, light=False, hdu=0):
        """
        Set up metadata, the initial scene and the index arrays.

        Parameters
        ----------
        initial : array_like
            Initial square 2-D scene. It is copied, cast to float if it is
            not already a floating/complex dtype, has its median (ignoring
            NaNs) subtracted, and any NaN pixels are then set to 0.
        image_list : sequence
            Image identifiers. The original comments refer to them as
            filenames; they are used as keys of ``self.mask_list`` and
            ``self.centers``.
        mask_list : sequence, optional
            One entry per element of ``image_list``. It is stored as a dict
            keyed by the matching ``image_list`` entry. An IndexError is
            raised if it is shorter than ``image_list``.
        invert, square, psfreg, sceneL2, dc, hdu
            Stored unchanged as attributes of the same name. They are
            consumed elsewhere in the class (not visible here).
        outdir : str
            Output directory, stored as an absolute path.
        centers : sequence, optional
            One entry per element of ``image_list``. It is stored as a dict
            keyed by the matching ``image_list`` entry, or None.
        psf_hw : int
            PSF half-width. The modelled region has side
            ``initial.shape[0] - 2 * psf_hw`` (stored as ``self.size``).
        kernel : ndarray, optional
            Kernel for 'light deconvolution'. Defaults to an 11x11
            unit-sigma Gaussian normalised to sum to 1.
        light : bool
            Stored as ``self.light``. Presumably it toggles the 'light
            deconvolution' that uses ``self.kernel``; TODO confirm.

        Raises
        ------
        AssertionError
            If the initial scene is not square.

        """
        # Metadata.
        self.image_list = image_list
        if mask_list is not None:
            # Indexing (rather than zip) keeps the original behaviour of
            # raising IndexError when mask_list is too short.
            self.mask_list = {k: mask_list[i]
                              for i, k in enumerate(image_list)}
        else:
            self.mask_list = {}
        self.invert = invert
        self.square = square
        self.outdir = os.path.abspath(outdir)
        self.psf_hw = psf_hw
        self.psfreg = psfreg
        self.sceneL2 = sceneL2
        self.dc = dc
        # Honour the caller's choice. This was previously hard-coded to
        # False, silently ignoring the ``light`` argument.
        self.light = light
        self.hdu = hdu

        # Sort out the center vector and save it as a dictionary associated
        # with specific filenames.
        self.centers = centers
        if centers is not None:
            self.centers = {k: centers[i]
                            for i, k in enumerate(image_list)}

        # Inference parameters.
        self.sky = 0
        scene = np.array(initial)
        # The in-place subtraction below needs a floating dtype. With an
        # integer input, ``-=`` with a float median would raise a casting
        # error.
        if not np.issubdtype(scene.dtype, np.inexact):
            scene = scene.astype(float)
        self.scene = scene

        # 'Sky'-subtract the initial scene. ``nanmedian`` ignores masked (NaN)
        # pixels. A plain ``median`` would return NaN whenever any pixel is
        # masked, making the whole scene NaN and hence all zeros below.
        self.scene -= np.nanmedian(self.scene)

        # Deal with the masked pixels if there are any in the initial scene
        # by setting them to the 'sky' level (0 after the subtraction above).
        self.scene[np.isnan(self.scene)] = 0.0

        # Check the dimensions of the initial scene and set the size.
        shape = self.scene.shape
        assert shape[0] == shape[1], "The initial scene needs to be square."
        self.size = shape[0] - 2 * self.psf_hw

        # The 'kernel' used for 'light deconvolution'.
        if kernel is None:
            # Default: 11x11 Gaussian with unit sigma, normalised to sum 1.
            offsets = np.arange(-5, 6)
            self.kernel = np.exp(-0.5 * (offsets[:, None] ** 2
                                         + offsets[None, :] ** 2))
            self.kernel /= np.sum(self.kernel)
        else:
            self.kernel = kernel

        # Index gymnastics. ``self.size + 2 * self.psf_hw`` is the side length
        # of the (square) initial scene.
        full_side = self.size + 2 * self.psf_hw
        self.scene_mask = utils.unravel_scene(full_side, self.psf_hw)
        self.psf_rows, self.psf_cols = \
                utils.unravel_psf(full_side, self.psf_hw)
コード例 #2
0
ファイル: tests.py プロジェクト: davidwhogg/TheThresher
    def test_unravel_psf(self):
        """
        Test to make sure that the PSF unraveling yields the correct
        results.

        Calls ``utils.unravel_psf(S, P)`` with ``S = 4`` and ``P = 1`` and
        compares the returned ``(rows, cols)`` index arrays (shape and
        values) against hand-enumerated expected values.

        """
        S, P = 4, 1
        rows, cols = utils.unravel_psf(S, P)

        # Brute-force expected column indices. Each inner list holds the nine
        # flat (row-major) indices of a 3x3 patch of the 4x4 grid. The
        # comment gives the (row, col) of the patch's top-left corner.
        b_cols = np.array(
            [
                [0, 1, 2, 4, 5, 6, 8, 9, 10],  # (0, 0)
                [1, 2, 3, 5, 6, 7, 9, 10, 11],  # (0, 1)
                [4, 5, 6, 8, 9, 10, 12, 13, 14],  # (1, 0)
                [5, 6, 7, 9, 10, 11, 13, 14, 15],  # (1, 1)
            ],
            dtype=int,
        ).ravel()

        # Expected row indices: patch index k, repeated (2 * P + 1) ** 2
        # times, for each of the (S - 2 * P) ** 2 patches.
        b_rows = np.repeat(np.arange((S - 2 * P) ** 2), (2 * P + 1) ** 2)

        # Unlike ``np.all(a == b)``, assert_array_equal also checks shapes
        # and reports which array and which elements differ.
        np.testing.assert_array_equal(rows, b_rows)
        np.testing.assert_array_equal(cols, b_cols)
コード例 #3
0
    def test_unravel_psf(self):
        """
        Test to make sure that the PSF unraveling yields the correct
        results.

        Calls ``utils.unravel_psf(S, P)`` with ``S = 4`` and ``P = 1`` and
        compares the returned ``(rows, cols)`` index arrays (shape and
        values) against hand-enumerated expected values.

        """
        S, P = 4, 1
        rows, cols = utils.unravel_psf(S, P)

        # Brute-force expected column indices. Each inner list holds the nine
        # flat (row-major) indices of a 3x3 patch of the 4x4 grid. The
        # comment gives the (row, col) of the patch's top-left corner.
        b_cols = np.array([
            [0, 1, 2, 4, 5, 6, 8, 9, 10],  # (0, 0)
            [1, 2, 3, 5, 6, 7, 9, 10, 11],  # (0, 1)
            [4, 5, 6, 8, 9, 10, 12, 13, 14],  # (1, 0)
            [5, 6, 7, 9, 10, 11, 13, 14, 15],  # (1, 1)
        ], dtype=int).ravel()

        # Expected row indices: patch index k, repeated (2 * P + 1) ** 2
        # times, for each of the (S - 2 * P) ** 2 patches.
        b_rows = np.repeat(np.arange((S - 2 * P)**2), (2 * P + 1)**2)

        # Unlike ``np.all(a == b)``, assert_array_equal also checks shapes
        # and reports which array and which elements differ.
        np.testing.assert_array_equal(rows, b_rows)
        np.testing.assert_array_equal(cols, b_cols)