Example #1
0
    def __init__(self, nhex=15, narr=200, extrasupportindex=None):
        """
        Set up a least-squares hexike fitter for (narr, narr) arrays.

        Input: nhex: number of hexike polynomials (as returned by
               z.hexike_basis, index 0 being piston) to use in the fit
        Input: narr: the live pupil array size you want to use
        Input: extrasupportindex: optional index/boolean mask into a
               (narr, narr) array marking extra pixels to exclude from the
               fitting support

        Sets up:
          - self.grid, self.grid_rho, self.grid_phi: normalised coordinate grids
          - self.grid_mask / self.grid_outside: hexagonal support (minus any
            extra excluded pixels) and its complement
          - self.hex_list: the hexikes restricted to the support; every mode
            except piston is scaled to unit standard deviation over the support
          - self.cov_mat: 'overlap integrals' (Gram matrix) of the hexikes
          - self.cov_mat_in: its pseudo-inverse, so it's 'ready to fit' an
            incoming array
        """
        self.narr = narr
        self.nhex = nhex  # tbd - allowed numbers from Pascal's Triangle sum(n) starting from n=1, viz. n(n+1)/2

        # Coordinates normalised so that +-1 is half an array width away from
        # the centre pixel narr // 2.  (N.float was only an alias of the
        # builtin float and has been removed from recent NumPy releases.)
        self.grid = (N.indices((self.narr, self.narr), dtype=float) -
                     self.narr // 2) / (float(self.narr) * 0.5)
        self.grid_rho = (self.grid**2.0).sum(0)**0.5
        self.grid_phi = N.arctan2(self.grid[0], self.grid[1])

        # Fitting support: the hexagonal aperture minus any extra excluded
        # pixels.  It is built before the hexikes are normalised so that the
        # normalisation, the Gram matrix and the exclusion all use the same
        # final mask (the exclusion is no longer overwritten afterwards).
        self.grid_mask = z.hex_aperture(npix=self.narr) == 1
        if extrasupportindex is not None:
            self.grid_mask[extrasupportindex] = False
        self.grid_outside = ~self.grid_mask

        # Compute the hexikes once.  hexike_basis may fill the region outside
        # the hexagon with NaN; zero those so multiplying by the mask cannot
        # propagate NaNs into the statistics and the Gram matrix below.
        self.hex_list = N.nan_to_num(
            N.asarray(z.hexike_basis(nterms=self.nhex, npix=self.narr),
                      dtype=float))

        # Restrict every hexike to the support and force all modes except
        # piston (index 0, left unscaled - presumably because a constant has
        # zero spread) to unit standard deviation over the hex support.
        for h in range(len(self.hex_list)):
            hfunc = self.hex_list[h] * self.grid_mask
            if h > 0:
                hfunc = hfunc / hfunc[self.grid_mask].std()
            self.hex_list[h] = hfunc

        # Gram ("covariance") matrix of the basis: cov_mat[i, j] = sum(h_i * h_j)
        flat = self.hex_list.reshape(len(self.hex_list), -1)
        self.cov_mat = flat.dot(flat.T)
        # Pseudo-inverse (SVD based): coeffs = cov_mat_in . <h_i, surface>
        self.cov_mat_in = N.linalg.pinv(self.cov_mat)
Example #2
0
def _test_cross_hexikes(testj=4, nterms=10, npix=500):
    """Numerically check that hexikes are mutually orthogonal.

    Hexike number ``testj`` is multiplied by every other hexike of an
    ``nterms``-term basis; the finite part of each product must sum to
    (nearly) zero. The piston term (j=1) and ``testj`` itself are skipped.

    This is a helper function for test_cross_hexike.

    Parameters :
    --------------
    testj : int
        1-based index of the hexike to test against the others
    nterms : int
        Number of hexike terms generated and compared against (1..nterms)
    npix : int
        Size of array to use for this test
    """
    basis = zernike.hexike_basis(nterms=nterms, npix=npix)
    reference = basis[testj - 1]
    assert np.isfinite(reference).any(), "Hexike calculation failure; all NaNs."

    skipped = (1, testj)  # piston term and the reference itself
    for j, other in enumerate(basis, start=1):
        if j in skipped:
            continue
        product = other * reference
        cross_sum = np.abs(product[np.isfinite(product)].sum())

        # Threshold was originally 1e-9, but 1.19e-9 showed up on some
        # machines (not always); a slightly relaxed 2e-9 is acceptable.
        assert cross_sum < 2e-9, (
            "orthogonality failure, Sum[Hexike(j={}) * Hexike(j={})] = {} (> 2e-9)"
            .format(j, testj, cross_sum))
Example #3
0
    def fit_hexikes_to_surface(self, surface, choosemodes=False):
        """
        Least-squares fit of the hexike basis self.hex_list to a surface.

        Input: surface: input surface to be fit (2D array, same shape as each
               entry of self.hex_list)
        Input: choosemodes: False to keep every mode, or an array-like with one
               entry per mode; modes whose entry is 0 get a zero coefficient
               (applied before reconstruction and reflected in hcoeffs)
        Output: hcoeffs: 1d vector of coefficients of the fit
                (len(self.hex_list) in length)
        Output: rec_wf: the 'recovered wavefront' - i.e. the fitted hexikes
                summed, in same array size as surface
        Output: res_wf: (surface - rec_wf) * self.grid_mask, i.e. the residual
                error in the fit over the support
        """
        basis = N.asarray(self.hex_list)
        flat_basis = basis.reshape(len(basis), -1)

        # Inner product of each hexike with the surface
        wf_hex_inprod = flat_basis.dot(N.ravel(surface))

        # Given the inner product vector of the surface with the basis,
        # solve the normal equations via the precomputed pseudo-inverse
        hcoeffs = N.dot(self.cov_mat_in, wf_hex_inprod)

        if choosemodes is not False:
            # asarray so that plain lists select modes element-wise instead
            # of comparing the whole list to 0 (which is just False)
            hcoeffs[N.asarray(choosemodes) == 0] = 0

        # Reconstruct with the very basis the coefficients refer to
        # (normalised and masked), not a freshly generated raw one
        rec_wf = N.tensordot(hcoeffs, basis, axes=1)

        return hcoeffs, rec_wf, (surface - rec_wf) * self.grid_mask
Example #4
0
def _test_cross_hexikes(testj=4, nterms=10, npix=500):
    """Verify the functions are orthogonal, by taking the
    integrals of a given Hexike times N other ones.

    The piston term (j=1) and the tested hexike itself are skipped; NaN
    pixels of each product are ignored in the sum.

    Parameters :
    --------------
    testj : int
        1-based index of the hexike to test against the others
    nterms : int
        Test that polynomial against those from 1 to this N
    npix : int
        Size of array to use for this test
    """

    hexike_basis = zernike.hexike_basis(nterms=nterms, npix=npix)
    test_hexike = hexike_basis[testj - 1]
    # Without this guard an all-NaN basis would pass vacuously: every product
    # would be NaN, the finite selection empty, and its sum exactly 0.
    assert np.sum(
        np.isfinite(test_hexike)) > 0, "Hexike calculation failure; all NaNs."
    for idx, hexike_array in enumerate(hexike_basis):
        j = idx + 1
        if j == testj or j == 1:
            continue  # discard piston term and self
        prod = hexike_array * test_hexike
        cross_sum = np.abs(prod[np.isfinite(prod)].sum())

        # Threshold was originally 1e-9, but 1.19e-9 has been observed on
        # some machines (not always); relaxing the criterion slightly.
        assert cross_sum < 2e-9, (
            "orthogonality failure, Sum[Hexike(j={}) * Hexike(j={})] = {} (> 2e-9)"
            .format(j, testj, cross_sum))
Example #5
0
def _test_cross_hexikes(testj=4, nterms=10, npix=500):
    """Assert that one hexike is orthogonal to the others in its basis.

    The hexike with 1-based index ``testj`` is multiplied pixel-wise by each
    remaining hexike (piston excluded) and the finite pixels of each product
    are summed; every such sum must be below 2e-9 in magnitude.

    This is a helper function for test_cross_hexike.

    Parameters :
    --------------
    testj : int
        1-based index of the hexike under test
    nterms : int
        Number of hexike terms to generate (compared against 1..nterms)
    npix : int
        Size of array to use for this test
    """
    hexikes = zernike.hexike_basis(nterms=nterms, npix=npix)
    probe = hexikes[testj - 1]
    has_finite = np.isfinite(probe).any()
    assert has_finite, "Hexike calculation failure; all NaNs."

    for number, candidate in enumerate(hexikes, start=1):
        # piston (1) and the probe itself are not meaningful comparisons
        if number in (1, testj):
            continue
        overlap = candidate * probe
        finite = np.isfinite(overlap)
        cross_sum = np.abs(overlap[finite].sum())

        # Threshold was originally 1e-9, but 1.19e-9 appeared on some
        # machines (not always), so the bound is relaxed slightly to 2e-9.
        message = "orthogonality failure, Sum[Hexike(j={}) * Hexike(j={})] = {} (> 2e-9)"
        assert cross_sum < 2e-9, message.format(number, testj, cross_sum)
Example #6
0
def projectionfilter(data,
                     nterms=None,
                     bases=None,
                     npix=None,
                     basis_type='Zernike',
                     outside='nan',
                     basis_kwds=None):
    """ Filtering reconstructed band structure using orthogonal polynomial approximation.

    The data are projected onto the basis with ``decomposition_hex2d`` and
    re-synthesised from the coefficients with ``reconstruction_hex2d``.

    **Parameters**\n
    data: 2D array
        Band dispersion in 2D to filter (must be square).
    nterms: int | None
        Number of terms. Required when ``bases`` is None; otherwise it defaults
        to the number of supplied bases.
    bases: 3D array | None
        Bases for decomposition, with the basis index on axis 0.
    npix: int | None
        Size (number of pixels) in one direction of each basis term. Defaults
        to ``data.shape[0]``.
    basis_type: str | 'Zernike'
        Type of basis to use for filtering.
    outside: numeric/str | 'nan'
        Values to fill for regions outside the Brillouin zone boundary:
        'nan' passes the reconstruction through ``u.to_masked(recon, val=0)``,
        0 returns it unchanged. Any other value raises ValueError.
    basis_kwds: dictionary | None
        Keywords for basis generator (see `poppy.zernike.hexike_basis()` if hexagonal Zernike polynomials are used).
        Its ``outside`` keyword defaults to 0 unless given here.

    **Return**\n
    recon, coeffs: the filtered band and its basis coefficients.
    """
    if not (outside == 'nan' or outside == 0):
        raise ValueError(
            "`outside` must be 'nan' or 0, got {!r}.".format(outside))

    if nterms is None:
        if bases is None:
            raise ValueError('Either `nterms` or `bases` must be specified.')
        nterms = len(bases)  # basis index is on axis 0 (baxis=0 below)
    nterms = int(nterms)

    if npix is None:
        npix = data.shape[0]  # the decomposition requires square input

    # Generate basis functions
    if bases is None:
        if basis_type == 'Zernike':
            kwds = {} if basis_kwds is None else dict(basis_kwds)
            # hexike_basis fills the region outside the hexagon with NaN by
            # default, which would poison the pseudo-inverse used by the
            # decomposition; use 0 there (as decomposition_hex2d does when it
            # builds its own bases) unless the caller says otherwise.
            kwds.setdefault('outside', 0)
            bases = ppz.hexike_basis(nterms=nterms, npix=npix, **kwds)

    # Decompose into the given basis
    coeffs = decomposition_hex2d(data,
                                 bases=bases,
                                 baxis=0,
                                 nterms=nterms,
                                 basis_type=basis_type,
                                 ret='coeffs')

    # Reconstruct the smoothed version of the energy band
    recon = reconstruction_hex2d(coeffs,
                                 bases=bases,
                                 baxis=0,
                                 npix=npix,
                                 basis_type=basis_type,
                                 ret='band')

    if outside == 'nan':
        recon = u.to_masked(recon, val=0)
    return recon, coeffs
Example #7
0
def decomposition_hex2d(band,
                        bases=None,
                        baxis=0,
                        nterms=100,
                        basis_type='Zernike',
                        ret='coeffs'):
    """ Decompose energy band in 3D momentum space using the orthogonal polynomials in a hexagon.

    The coefficients are the least-squares solution obtained with the
    Moore-Penrose pseudo-inverse of the (nbases, npixels) basis matrix.

    **Parameters**\n
    band: 2D array
        2D electronic band structure (must be square).
    bases: 3D array | None
        Matrix composed of bases to decompose into. If None, ``nterms``
        hexagonal Zernike polynomials (vertical, 0 outside) of the band's size
        are generated.
    baxis: int | 0
        Axis of the basis index (only used when ``bases`` is given).
    nterms: int | 100
        Number of basis terms (only used when ``bases`` is None).
    basis_type: str | 'Zernike'
        Type of basis to use (only 'Zernike' is implemented).
    ret: str | 'coeffs'
        Options for the return values: 'coeffs' returns the coefficients,
        'all' returns (coeffs, bases).

    **Raises**\n
    ValueError if the band is not square, if its size differs from that of a
    basis term, or if ``ret`` is unknown; NotImplementedError for basis types
    other than 'Zernike'.
    """
    if ret not in ('coeffs', 'all'):
        raise ValueError(
            "`ret` must be 'coeffs' or 'all', got {!r}.".format(ret))

    nbr, nbc = band.shape
    if nbr != nbc:
        raise ValueError('Input band surface should be square!')

    if bases is None:
        if basis_type != 'Zernike':
            raise NotImplementedError(
                'Basis type {!r} is not implemented.'.format(basis_type))
        bases = ppz.hexike_basis(nterms=nterms,
                                 npix=nbr,
                                 vertical=True,
                                 outside=0)
    elif baxis != 0:
        bases = np.moveaxis(bases, baxis, 0)

    nbas, nbasr, nbasc = bases.shape
    if nbasr * nbasc != band.size:
        raise ValueError(
            'Band has {} pixels but each basis term has {}.'.format(
                band.size, nbasr * nbasc))

    # Least squares: coeffs = (B^+)^T . b with B the (nbas, npixels) basis matrix
    basis_matrix = bases.reshape((nbas, nbasr * nbasc))
    coeffs = np.linalg.pinv(basis_matrix).T.dot(np.ravel(band))

    if ret == 'coeffs':
        return coeffs
    return coeffs, bases
Example #8
0
def reconstruction_hex2d(coeffs,
                         bases=None,
                         baxis=0,
                         npix=256,
                         basis_type='Zernike',
                         ret='band'):
    """ Reconstruction of energy band in 3D momentum space using orthogonal polynomials
    and the term-wise coefficients.

    The band is the coefficient-weighted sum of the basis terms.

    **Parameters**\n
    coeffs: 1D array-like
        Polynomial coefficients to use in reconstruction (flattened; lists are
        accepted too).
    bases: 3D array | None
        Matrix composed of bases to reconstruct from. If None, as many
        hexagonal Zernike polynomials (vertical, 0 outside) as there are
        coefficients are generated with size ``npix``.
    baxis: int | 0
        Axis of the basis index (only used when ``bases`` is given).
    npix: int | 256
        Number of pixels along one side in the square image (only used when
        ``bases`` is None).
    basis_type: str | 'Zernike'
        Type of basis to use (only 'Zernike' is implemented).
    ret: str | 'band'
        Options for the return values: 'band' returns the reconstructed band,
        'all' returns (band, bases).

    **Raises**\n
    ValueError if ``ret`` is unknown or the number of coefficients differs
    from the number of bases; NotImplementedError for basis types other than
    'Zernike'.
    """
    if ret not in ('band', 'all'):
        raise ValueError(
            "`ret` must be 'band' or 'all', got {!r}.".format(ret))

    coeffs = np.asarray(coeffs).ravel()
    nterms = coeffs.size

    if bases is None:
        if basis_type != 'Zernike':
            raise NotImplementedError(
                'Basis type {!r} is not implemented.'.format(basis_type))
        bases = ppz.hexike_basis(nterms=nterms,
                                 npix=npix,
                                 vertical=True,
                                 outside=0)
    elif baxis != 0:
        bases = np.moveaxis(bases, baxis, 0)

    nbas, nbasr, nbasc = bases.shape
    if nbas != nterms:
        raise ValueError(
            'Got {} coefficients for {} basis terms.'.format(nterms, nbas))

    band_recon = coeffs.dot(bases.reshape((nbas, nbasr * nbasc))).reshape(
        (nbasr, nbasc))

    if ret == 'band':
        return band_recon
    return band_recon, bases
Example #9
0
    def basisgen(self,
                 nterms,
                 npix,
                 vertical=True,
                 outside=0,
                 basis_type='Zernike'):
        """ Generate polynomial bases for energy band synthesis.

        Stores the generated bases in ``self.bases``.

        **Parameters**\n
        nterms: int
            Number of basis terms.
        npix: int
            Number of pixels along one side of each basis term.
        vertical: bool | True
            Passed to ``ppz.hexike_basis`` (hexagon orientation).
        outside: numeric | 0
            Value used outside the hexagon.
        basis_type: str | 'Zernike'
            Type of basis (only 'Zernike' is implemented).

        **Raises**\n
        NotImplementedError for any other ``basis_type``.
        """
        if basis_type != 'Zernike':
            # Fail loudly instead of leaving self.bases unset or stale.
            raise NotImplementedError(
                'Basis type {!r} is not implemented.'.format(basis_type))

        self.bases = ppz.hexike_basis(nterms=nterms,
                                      npix=npix,
                                      vertical=vertical,
                                      outside=outside)
Example #10
0
def hexmask(hexdiag=128,
            imside=256,
            image=None,
            padded=False,
            margins=None,
            pad_top=None,
            pad_bottom=None,
            pad_left=None,
            pad_right=None,
            vertical=True,
            outside='nan',
            ret='mask',
            **kwargs):
    """ Generate a hexagonal mask. To use the function, either the argument ``imside`` or ``image`` should be
    given. The image padding on four sides could be specified with either ``margins`` altogether or separately
    with the individual arguments ``pad_xxx``. For the latter, at least two independent padding values are needed.

    **Parameters**\n
    hexdiag: int | 128
        Number of pixels along the hexagon's diagonal (used when ``image`` is None).
    imside: int | 256
        Number of pixels along the side of the (square) reference image.
    image: 2D array | None
        2D reference image to construct the mask for. If the reference (image) is given, each side
        of the generated mask is at least that of the smallest dimension of the reference.
    padded: bool | False
        Option to pad the image (need to set to True to enable the margins).
    margins: list/tuple | None
        Margins of the image [top, bottom, left, right]. Overrides the `pad_xxx` arguments.
    pad_top, pad_bottom, pad_left, pad_right : int, int, int, int | None, None, None, None
        Number of padded pixels on each of the four sides of the image.
    vertical: bool | True
        Option to align the diagonal of the hexagon with the vertical image axis.
    outside: numeric/str | 'nan'
        Pixel value outside the masked region. A string keeps NaN there; a
        number replaces the NaNs with that number.
    ret: str | 'mask'
        Return option ('mask', 'masked_image', 'all'). 'all' returns
        (mask, [top, bottom, left, right]); the margins are all 0 when not padded.
    imshape (keyword): tuple
        Reference image shape when ``image`` is None; defaults to (imside, imside).

    **Raises**\n
    ValueError for an unknown ``ret`` or for ret='masked_image' without ``image``.
    """
    if ret not in ('mask', 'masked_image', 'all'):
        raise ValueError(
            "`ret` must be 'mask', 'masked_image' or 'all', got {!r}.".format(ret))
    if ret == 'masked_image' and image is None:
        raise ValueError("`image` is required when ret='masked_image'.")

    if image is not None:
        imshape = image.shape
        npix = min(imshape)
    else:
        imshape = kwargs.pop('imshape', (imside, imside))
        npix = hexdiag
    # Piston term of the hexike basis; presumably NaN outside the hexagon
    # (the NaNs are replaced below when a numeric `outside` is requested).
    mask = ppz.hexike_basis(nterms=1, npix=npix, vertical=vertical)[0, ...]

    # No padding unless requested; these are also what ret='all' reports.
    top = bottom = left = right = 0
    if padded:
        if margins is not None and len(margins) == 4:
            top, bottom, left, right = margins
        else:
            # Total padding needed along each axis to reach the reference
            # shape, measured from the actual mask size (which is min(imshape)
            # rather than hexdiag when `image` is given).
            padsides = np.abs(np.asarray(imshape) - np.asarray(mask.shape))
            top, bottom = u.nonneg_sum_decomposition(a=pad_top,
                                                     b=pad_bottom,
                                                     absum=padsides[0])
            left, right = u.nonneg_sum_decomposition(a=pad_left,
                                                     b=pad_right,
                                                     absum=padsides[1])

        mask = np.pad(mask, ((top, bottom), (left, right)),
                      mode='constant',
                      constant_values=np.nan)

    if np.isscalar(outside) and not isinstance(outside, str):
        mask = np.where(np.isnan(mask), outside, mask)

    if ret == 'mask':
        return mask
    if ret == 'masked_image':
        return mask * image
    return mask, [top, bottom, left, right]
Example #11
0
def analytical_model(zernike_pol, coef, cali=False):
    """Compute the analytical PASTIS image for one local Zernike mode.

    Implements eq. 13 of Leboulleux et al. 2018: the image is the intensity of
    the Fourier transform of the local Zernike on a single segment (the
    envelope), multiplied by the sum of the squared segment coefficients plus
    twice the cosine-weighted sum over the non-redundant segment pairs.
    Instrument parameters are read from CONFIG_INI and the segmentation
    products from '<local_data_path>/active/segmentation'.

    :param zernike_pol: int; 1-based index of the local Zernike applied to every segment (1 = piston).
                        Must lie between 1 and the configured 'max_zern'.
    :param coef: per-segment coefficients, one per segment. Presumably an astropy Quantity in nm
                 (the pair products are accumulated in nm^2) - TODO confirm against callers.
                 The caller's array is not modified.
    :param cali: bool; True if we already have calibration coefficients to use. False if we still need to create them.
    :return: (dh_psf, intensity): the image restricted to the dark hole (zero outside it) and the entire final image.
    :raises ValueError: if the configured telescope is neither 'JWST' nor 'ATLAST', or zernike_pol is out of range.
    """

    #-# Parameters
    dataDir = os.path.join(CONFIG_INI.get('local', 'local_data_path'),
                           'active')
    telescope = CONFIG_INI.get('telescope', 'name')
    if telescope not in ('JWST', 'ATLAST'):
        # Every telescope-specific branch below handles only these two; any
        # other name would otherwise fail much later with a NameError.
        raise ValueError(
            "Unsupported telescope '{}': expected 'JWST' or 'ATLAST'.".format(
                telescope))
    nb_seg = CONFIG_INI.getint(telescope, 'nb_subapertures')
    tel_size_m = CONFIG_INI.getfloat(telescope, 'diameter') * u.m
    real_size_seg = CONFIG_INI.getfloat(
        telescope, 'flat_to_flat'
    )  # in m, size in meters of an individual segment flat to flat
    size_seg = CONFIG_INI.getint(
        'numerical',
        'size_seg')  # pixel size of an individual segment tip to tip
    wvln = CONFIG_INI.getint(telescope, 'lambda') * u.nm
    inner_wa = CONFIG_INI.getint(telescope, 'IWA')
    outer_wa = CONFIG_INI.getint(telescope, 'OWA')
    tel_size_px = CONFIG_INI.getint(
        'numerical', 'tel_size_px')  # pupil diameter of telescope in pixels
    im_size_pastis = CONFIG_INI.getint(
        'numerical', 'im_size_px_pastis')  # image array size in px
    sampling = CONFIG_INI.getfloat('numerical', 'sampling')  # sampling
    size_px_tel = tel_size_m / tel_size_px  # size of one pixel in pupil plane in m
    px_sq_to_rad = (size_px_tel * np.pi / tel_size_m) * u.rad
    zern_max = CONFIG_INI.getint('zernikes', 'max_zern')
    sz = CONFIG_INI.getint('numerical', 'im_size_lamD_hcipy')

    # The local mode is taken from a basis of zern_max modes at index
    # zernike_pol - 1, so only 1..zern_max are meaningful.
    if not 1 <= zernike_pol <= zern_max:
        raise ValueError(
            'zernike_pol must be between 1 and {}, got {}.'.format(
                zern_max, zernike_pol))

    # Create Zernike mode object for easier handling
    zern_mode = util.ZernikeMode(zernike_pol)

    #-# Mean subtraction for piston
    if zernike_pol == 1:
        # Rebind rather than subtract in place so the caller's array is untouched
        coef = coef - np.mean(coef)

    #-# Generic segment shapes

    if telescope == 'JWST':
        # Load pupil from file
        pupil = fits.getdata(
            os.path.join(dataDir, 'segmentation', 'pupil.fits'))

        # Put pupil in randomly picked, slightly larger image array
        pup_im = np.copy(pupil)  # remove if lines below this are active
        #pup_im = np.zeros([tel_size_px, tel_size_px])
        #lim = int((pup_im.shape[1] - pupil.shape[1])/2.)
        #pup_im[lim:-lim, lim:-lim] = pupil
        # test_seg = pupil[394:,197:315]    # this is just so that I can display an individual segment when the pupil is 512
        # test_seg = pupil[:203,392:631]    # ... when the pupil is 1024
        # one_seg = np.zeros_like(test_seg)
        # one_seg[:110, :] = test_seg[8:, :]    # this is the centered version of the individual segment for 512 px pupil

        # Create a mini-segment (one individual segment from the segmented aperture)
        mini_seg_real = poppy.NgonAperture(
            name='mini', radius=real_size_seg
        )  # creating real mini segment shape with poppy
        #test = mini_seg_real.sample(wavelength=wvln, grid_size=flat_diam, return_scale=True)   # fix its sampling with wavelength
        mini_hdu = mini_seg_real.to_fits(wavelength=wvln,
                                         npix=size_seg)  # make it a fits file
        mini_seg = mini_hdu[
            0].data  # extract the image data from the fits file

    elif telescope == 'ATLAST':
        # Create mini-segment
        pupil_grid = hcipy.make_pupil_grid(dims=tel_size_px,
                                           diameter=real_size_seg)
        focal_grid = hcipy.make_focal_grid(
            pupil_grid, sampling, sz, wavelength=wvln.to(
                u.m).value)  # fov = lambda/D radius of total image
        prop = hcipy.FraunhoferPropagator(pupil_grid, focal_grid)

        mini_seg_real = hcipy.hexagonal_aperture(circum_diameter=real_size_seg,
                                                 angle=np.pi / 2)
        mini_seg_hc = hcipy.evaluate_supersampled(
            mini_seg_real, pupil_grid, 4
        )  # the supersampling number doesn't really matter in context with the other numbers
        mini_seg = mini_seg_hc.shaped  # make it a 2D array

        # Redefine size_seg if using HCIPy
        size_seg = mini_seg.shape[0]

        # Make stand-in pupil for DH array
        pupil = fits.getdata(
            os.path.join(dataDir, 'segmentation', 'pupil.fits'))
        pup_im = np.copy(pupil)

    #-# Generate a dark hole mask
    #TODO: simplify DH generation and usage
    dh_area = util.create_dark_hole(
        pup_im, inner_wa, outer_wa, sampling
    )  # this might become a problem if pupil size is not same like pastis image size. fine for now though.
    if telescope == 'ATLAST':
        dh_sz = util.zoom_cen(dh_area, sz * sampling)

    #-# Import information from segmentation script
    Projection_Matrix = fits.getdata(
        os.path.join(dataDir, 'segmentation', 'Projection_Matrix.fits'))
    vec_list = fits.getdata(
        os.path.join(dataDir, 'segmentation', 'vec_list.fits'))  # in pixels
    NR_pairs_list = fits.getdata(
        os.path.join(dataDir, 'segmentation', 'NR_pairs_list_int.fits'))

    # Figure out how many NRPs we're dealing with
    NR_pairs_nb = NR_pairs_list.shape[0]

    #-# Choose whether to use calibration factors
    if cali:
        filename = 'calibration_' + zern_mode.name + '_' + zern_mode.convention + str(
            zern_mode.index)
        ck = fits.getdata(
            os.path.join(dataDir, 'calibration', filename + '.fits'))
    else:
        ck = np.ones(nb_seg)

    coef = coef * ck

    #-# Generic coefficients
    # the coefficients in front of the non redundant pairs, the A_q in eq. 13 in Leboulleux et al. 2018
    generic_coef = np.zeros(
        NR_pairs_nb
    ) * u.nm * u.nm  # setting it up with the correct units this will have

    for q in range(NR_pairs_nb):
        for i in range(nb_seg):
            for j in range(i + 1, nb_seg):
                if Projection_Matrix[i, j, 0] == q + 1:
                    generic_coef[q] += coef[i] * coef[j]

    #-# Constant sum and cosine sum - calculating eq. 13 from Leboulleux et al. 2018
    if telescope == 'JWST':
        i_line = np.linspace(-im_size_pastis / 2., im_size_pastis / 2.,
                             im_size_pastis)
        tab_i, tab_j = np.meshgrid(i_line, i_line)
        cos_u_mat = np.zeros(
            (int(im_size_pastis), int(im_size_pastis), NR_pairs_nb))
    elif telescope == 'ATLAST':
        i_line = np.linspace(-(2 * sz * sampling) / 2.,
                             (2 * sz * sampling) / 2., (2 * sz * sampling))
        tab_i, tab_j = np.meshgrid(i_line, i_line)
        cos_u_mat = np.zeros((int((2 * sz * sampling)), int(
            (2 * sz * sampling)), NR_pairs_nb))

    # Calculating the cosine terms from eq. 13.
    # The -1 with each NR_pairs_list is because the segment names are saved starting from 1, but Python starts
    # its indexing at zero, so we have to make it start at zero here too.
    for q in range(NR_pairs_nb):
        # cos(b_q <dot> u): b_q with 1 <= q <= NR_pairs_nb is the basis of NRPS, meaning the distance vectors between
        #                   two segments of one NRP. We can read these out from vec_list.
        #                   u is the position (vector) in the detector plane. Here, those are the grids tab_i and tab_j.
        # We need to calculate the dot product between all b_q and u, so in each iteration (for q), we simply add the
        # x and y component.
        cos_u_mat[:, :, q] = np.cos(
            px_sq_to_rad *
            (vec_list[NR_pairs_list[q, 0] - 1, NR_pairs_list[q, 1] - 1, 0] *
             tab_i) + px_sq_to_rad *
            (vec_list[NR_pairs_list[q, 0] - 1, NR_pairs_list[q, 1] - 1, 1] *
             tab_j)) * u.dimensionless_unscaled

    sum1 = np.sum(
        coef**2
    )  # sum of all a_{k,l} in eq. 13 - this works only for single Zernikes (l fixed), because np.sum would sum over l too, which would be wrong.
    if telescope == 'JWST':
        sum2 = np.zeros(
            (int(im_size_pastis), int(im_size_pastis))
        ) * u.nm * u.nm  # setting it up with the correct units this will have
    elif telescope == 'ATLAST':
        sum2 = np.zeros(
            (int(2 * sz * sampling), int(2 * sz * sampling))) * u.nm * u.nm

    for q in range(NR_pairs_nb):
        sum2 = sum2 + generic_coef[q] * cos_u_mat[:, :, q]

    #-# Local Zernike
    if telescope == 'JWST':
        # Generate a basis of Zernikes with the mini segment being the support
        isolated_zerns = zern.hexike_basis(nterms=zern_max,
                                           npix=size_seg,
                                           rho=None,
                                           theta=None,
                                           vertical=False,
                                           outside=0.0)

        # Apply the currently used Zernike to the mini-segment (one single subaperture); the result is Zer.
        # zernike_pol has been validated to lie in 1..zern_max above, so Zer is always defined.
        if zernike_pol == 1:
            Zer = np.copy(mini_seg)
        else:
            Zer = mini_seg * isolated_zerns[zernike_pol - 1]

        # Fourier Transform of the Zernike - the global envelope
        mf = mft.MatrixFourierTransform()
        ft_zern = mf.perform(Zer, im_size_pastis / sampling, im_size_pastis)

    elif telescope == 'ATLAST':
        isolated_zerns = hcipy.make_zernike_basis(num_modes=zern_max,
                                                  D=real_size_seg,
                                                  grid=pupil_grid,
                                                  radial_cutoff=False)
        Zer = hcipy.Wavefront(mini_seg_hc * isolated_zerns[zernike_pol - 1],
                              wavelength=wvln.to(u.m).value)

        # Fourier transform the Zernike
        ft_zern = prop(Zer)

    #-# Final image
    if telescope == 'JWST':
        # Generating the final image that will get passed on to the outer scope, I(u) in eq. 13
        intensity = np.abs(ft_zern)**2 * (sum1.value + 2. * sum2.value)
    elif telescope == 'ATLAST':
        intensity = ft_zern.intensity.shaped * (sum1.value + 2. * sum2.value)

    # PASTIS is only valid inside the dark hole, so we cut out only that part
    if telescope == 'JWST':
        tot_dh_im_size = sampling * (outer_wa + 3)
        intensity_zoom = util.zoom_cen(
            intensity, tot_dh_im_size
        )  # zoom box is (owa + 3*lambda/D) wide, in terms of lambda/D
        dh_area_zoom = util.zoom_cen(dh_area, tot_dh_im_size)

        dh_psf = dh_area_zoom * intensity_zoom

    elif telescope == 'ATLAST':
        dh_psf = dh_sz * intensity

    # dh_psf is the image of the dark hole only, the pixels outside of it are zero
    # intensity is the entire final image
    return dh_psf, intensity