コード例 #1
0
    def loadCalibration(self):

        coef, domain = np.load(self.calibrationfilename)
        self.pixelstowavelengths = Legendre(coef, domain)
        self.polynomialdegree = self.pixelstowavelengths.degree()
        self.speak("loaded wavelength calibration"
                "from {0}".format(self.calibrationfilename))
コード例 #2
0
def subtract_legendre_fit(array2d: np.ndarray,
                          keep_offset: bool = False,
                          deg: int = 1) -> Optional[np.ndarray]:
    """
    Remove a separable Legendre background of degree ``deg`` from a 2D array.

    One Legendre series is fitted to the per-row means and another to the
    per-column means; both are subtracted from the data.
    deg = 0 ... subtract mean value
    deg = 1 ... subtract mean plane
    deg = 2 ... subtract simple curved mean surface
    deg = 3 ... also corrects "s-shaped" distortion
    ...

    :param array2d: 2D input data (left unmodified).
    :param keep_offset: if True the overall mean of the input is preserved,
        otherwise the result has the mean removed.
    :param deg: degree of both Legendre fits.
    :return: a new array with the fitted background removed.
    """
    if keep_offset and deg == 0:
        # Nothing but the offset would be removed, so hand back a plain copy.
        return array2d.copy()

    row_axis = np.linspace(-1, 1, array2d.shape[0])
    col_axis = np.linspace(-1, 1, array2d.shape[1])

    row_fit = Legendre.fit(row_axis, array2d.mean(axis=1), deg)
    col_fit = Legendre.fit(col_axis, array2d.mean(axis=0), deg)

    row_background = np.polynomial.legendre.legval(row_axis, row_fit.coef)
    col_background = np.polynomial.legendre.legval(col_axis, col_fit.coef)

    # Each fit contains the overall mean, so it has been removed twice:
    # add it back once to end at zero mean, twice to keep the offset.
    if keep_offset:
        offset = 2 * array2d.mean()
    else:
        offset = array2d.mean()

    flattened = array2d - row_background[:, np.newaxis]
    flattened = flattened - col_background
    return flattened + offset
コード例 #3
0
 def basis(self, type, M, X_star=None):
     """Evaluate a set of basis functions on the inputs.

     :param type: one of 'monomial', 'fourier' or 'legendre'.
     :param M: highest order of the basis.
     :param X_star: optional inputs to evaluate on; when None the stored
         ``self.X`` is used. Only column 0 is used for the monomial and
         Fourier bases.
     :returns: for 'monomial' an (n, M + 1) matrix of powers 0..M; for
         'fourier' an (n, 2 * (M + 1)) matrix with sin/cos column pairs for
         frequencies 0..M; for 'legendre' the degree-M Legendre polynomial
         evaluated element-wise on the inputs.
     :raises ValueError: if ``type`` is not a known basis name.
     """
     X = self.X if X_star is None else X_star
     n_points = np.shape(X)[0]

     # Compare strings by value: `is` only works by accident of interning.
     if type == 'monomial':
         Phi = np.zeros((n_points, M + 1))
         for c in range(M + 1):
             Phi[:, c] = X[:, 0].T**c
     elif type == 'fourier':
         Phi = np.zeros((n_points, (M + 1) * 2))
         for k in range(M + 1):
             # Columns come in (sin, cos) pairs, one pair per frequency k.
             Phi[:, 2 * k] = np.sin(k * np.pi * X[:, 0].T)
             Phi[:, 2 * k + 1] = np.cos(k * np.pi * X[:, 0].T)
     elif type == 'legendre':
         from numpy.polynomial import Legendre
         Phi = Legendre.basis(M)(X)
     else:
         raise ValueError("unknown basis type: {0!r}".format(type))
     return Phi
コード例 #4
0
 def loadCalibration(self):
     '''Load the wavelength calibration from ``self.calibrationfilename``.

     The loaded object is unpacked into ``(coef, domain)`` and used to
     build the Legendre pixel-to-wavelength polynomial. Sets
     ``self.pixelstowavelengths`` and ``self.polynomialdegree``.
     '''
     self.speak('trying to load calibration data')
     # NOTE(security): allow_pickle=True can execute arbitrary code while
     # unpickling -- only load calibration files you trust.
     coef, domain = np.load(self.calibrationfilename, allow_pickle=True)
     self.pixelstowavelengths = Legendre(coef, domain)
     self.polynomialdegree = self.pixelstowavelengths.degree()
     # The two literals are concatenated, so the first one needs the
     # trailing space (it used to print "calibrationfrom ...").
     self.speak("loaded wavelength calibration "
                "from {0}".format(self.calibrationfilename))
コード例 #5
0
def dg_interior_flux_matrix(p):
    """Build the interior flux matrix for the standard DG advection equations.

    Entry ``F[i, j]`` is ``basis.integrate_legendre_product`` applied to the
    derivative of the i-th Legendre polynomial and the j-th Legendre
    polynomial, for ``i, j = 0..p``.

    There is an analytical expression that could be used for this too
    (see the author's thesis, page 74).
    """
    n_modes = p + 1

    # Build the polynomials once instead of once per matrix entry.
    derivatives = [L.basis(i).deriv() for i in range(n_modes)]
    modes = [L.basis(j) for j in range(n_modes)]

    F = np.zeros((n_modes, n_modes))
    for i, dli in enumerate(derivatives):
        for j, lj in enumerate(modes):
            F[i, j] = basis.integrate_legendre_product(dli, lj)
    return F
コード例 #6
0
ファイル: dg_stability.py プロジェクト: marchdf/dg1d
def dg_interior_flux_matrix(p):
    """Build the interior flux matrix for the standard DG advection equations.

    Entry ``F[i, j]`` is ``basis.integrate_legendre_product`` applied to the
    derivative of the i-th Legendre polynomial and the j-th Legendre
    polynomial, for ``i, j = 0..p``.

    There is an analytical expression that could be used for this too
    (see the author's thesis, page 74).
    """
    size = p + 1

    # Build the polynomials once instead of once per matrix entry.
    test_derivs = [L.basis(row).deriv() for row in range(size)]
    trial_funcs = [L.basis(col) for col in range(size)]

    F = np.zeros((size, size))
    for row, d_test in enumerate(test_derivs):
        for col, trial in enumerate(trial_funcs):
            F[row, col] = basis.integrate_legendre_product(d_test, trial)
    return F
コード例 #7
0
ファイル: tracepol.py プロジェクト: AntoineDarveau/jwst-mtl
def wavelength_to_pix(wavelength,
                      tracepars,
                      m=1,
                      frame='dms',
                      subarray='SUBSTRIP256',
                      oversample=1):
    """Convert wavelengths to pixel coordinates for spectral order m.

    The trace polynomials are evaluated on the natural logarithm of the
    wavelength, converted to the requested frame and then multiplied by the
    oversampling factor.

    :param wavelength: wavelength values in microns.
    :param tracepars: the trace polynomial solutions returned by get_tracepars.
    :param m: the spectral order.
    :param frame: the coordinate frame of the output coordinates (nat, dms or sim).
    :param subarray: the subarray of the output coordinates (default 'SUBSTRIP256').
    :param oversample: the oversampling factor of the output coordinates.

    :type wavelength: array[float]
    :type tracepars: dict
    :type m: int
    :type frame: str
    :type subarray: str
    :type oversample: int

    :returns: specpix - the spectral pixel coordinates, spatpix - the spatial pixel coordinates,
    mask - the result of bounds_check on log(wavelength) against the 'spec_domain' limits.
    :rtype: Tuple(array[float], array[float], array[bool])
    """

    order_pars = tracepars[m]
    spec_domain = order_pars['spec_domain']

    # Polynomials mapping log(wavelength) to nat pixel coordinates.
    w2spec = Legendre(order_pars['spec_coef'], domain=spec_domain)
    w2spat = Legendre(order_pars['spat_coef'],
                      domain=order_pars['spat_domain'])

    # The polynomials and the bounds check all work on log(wavelength).
    log_wave = np.log(wavelength)
    specpix_nat = w2spec(log_wave)
    spatpix_nat = w2spat(log_wave)
    mask = bounds_check(log_wave, spec_domain[0], spec_domain[1])

    # Convert coordinates to the requested frame.
    specpix, spatpix = pix_ref_to_frame(specpix_nat,
                                        spatpix_nat,
                                        frame=frame,
                                        subarray=subarray)

    # Apply the oversampling factor on the way out.
    return specpix * oversample, spatpix * oversample, mask
コード例 #8
0
def icb_interface_flux_matrix(p, K, T):
    """Build the interface flux matrices for the ICB schemes (upwind fluxes).

    T denotes the translation necessary to evaluate the flux from the
    other cell.

    Returns the pair ``(G0, G1)``; each is the raw endpoint-product matrix
    multiplied by ``Ainv`` (from ``enhance.enhancement_matrices``), ``BL``
    and the reduction matrix ``R``.
    """
    n_modes = p + 1

    # Enhanced polynomial degree
    phat = p + len(K)
    n_enhanced = phat + 1

    # Endpoint values of every Legendre mode, evaluated once.
    at_left = [leg.legval(-1, L.basis(i).coef) for i in range(n_modes)]
    at_right = [leg.legval(1, L.basis(j).coef) for j in range(n_enhanced)]

    G0 = smp.zeros(n_modes, n_enhanced)
    G1 = smp.zeros(n_modes, n_enhanced)
    for i in range(n_modes):
        for j in range(n_enhanced):
            G0[i, j] = at_left[i] * at_right[j] * (T**-1)
            G1[i, j] = at_right[i] * at_right[j]

    # Enhancement matrices; only the inverse of A is used below.
    A, Ainv, B, Binv = enhance.enhancement_matrices(p, K)

    # Using the enhanced function in the flux (see notes 21/4/15)
    BL = smp.zeros(n_enhanced, n_enhanced)
    # NOTE(review): BR is built but never used in the returned matrices.
    BR = smp.zeros(n_enhanced, n_enhanced)
    for d in range(n_modes):
        BL[d, d] = 1
        BR[d, d] = 1
    for row, k in enumerate(K, start=n_modes):
        lk = L.basis(k)
        # NOTE(review): this norm is computed but not applied.
        int_lklk = basis.integrate_legendre_product(lk, lk)
        BL[row, row] = T
        BR[row, row] = (T**(-1))

    # Reduction matrix: identity block on top, delta(k, j) rows below.
    R = smp.zeros(n_enhanced, n_modes)
    for d in range(n_modes):
        R[d, d] = 1
    for row, k in enumerate(K, start=n_modes):
        for j in range(n_modes):
            R[row, j] = auxf.delta(k, j)

    # Convert to sympy matrices and apply enhancement and reduction.
    inverse_enhancement = smp.Matrix(Ainv)
    G0 = smp.Matrix(G0) * inverse_enhancement * BL * R
    G1 = smp.Matrix(G1) * inverse_enhancement * BL * R

    return G0, G1
コード例 #9
0
ファイル: dg_stability.py プロジェクト: marchdf/dg1d
def icb_interface_flux_matrix(p, K, T):
    """Build the interface flux matrices for the ICB schemes (upwind fluxes).

    T denotes the translation necessary to evaluate the flux from the
    other cell.

    Returns the pair ``(G0, G1)``; each is the raw endpoint-product matrix
    multiplied by ``Ainv`` (from ``enhance.enhancement_matrices``), ``BL``
    and the reduction matrix ``R``.
    """
    n_modes = p + 1

    # Enhanced polynomial degree
    phat = p + len(K)
    n_enhanced = phat + 1

    # Endpoint values of every Legendre mode, evaluated once.
    at_left = [leg.legval(-1, L.basis(i).coef) for i in range(n_modes)]
    at_right = [leg.legval(1, L.basis(j).coef) for j in range(n_enhanced)]

    G0 = smp.zeros(n_modes, n_enhanced)
    G1 = smp.zeros(n_modes, n_enhanced)
    for i in range(n_modes):
        for j in range(n_enhanced):
            G0[i, j] = at_left[i] * at_right[j] * (T**-1)
            G1[i, j] = at_right[i] * at_right[j]

    # Enhancement matrices; only the inverse of A is used below.
    A, Ainv, B, Binv = enhance.enhancement_matrices(p, K)

    # Using the enhanced function in the flux (see notes 21/4/15)
    BL = smp.zeros(n_enhanced, n_enhanced)
    # NOTE(review): BR is built but never used in the returned matrices.
    BR = smp.zeros(n_enhanced, n_enhanced)
    for d in range(n_modes):
        BL[d, d] = 1
        BR[d, d] = 1
    for row, k in enumerate(K, start=n_modes):
        lk = L.basis(k)
        # NOTE(review): this norm is computed but not applied.
        int_lklk = basis.integrate_legendre_product(lk, lk)
        BL[row, row] = T
        BR[row, row] = (T**(-1))

    # Reduction matrix: identity block on top, delta(k, j) rows below.
    R = smp.zeros(n_enhanced, n_modes)
    for d in range(n_modes):
        R[d, d] = 1
    for row, k in enumerate(K, start=n_modes):
        for j in range(n_modes):
            R[row, j] = auxf.delta(k, j)

    # Convert to sympy matrices and apply enhancement and reduction.
    inverse_enhancement = smp.Matrix(Ainv)
    G0 = smp.Matrix(G0) * inverse_enhancement * BL * R
    G1 = smp.Matrix(G1) * inverse_enhancement * BL * R

    return G0, G1
コード例 #10
0
def trend(p, t):
    """
    Evaluate the Legendre series with coefficients ``p`` at the times ``t``.

    The series domain is set to ``[t.min(), t.max()]``, so the times are
    mapped onto the standard Legendre window before evaluation.
    """
    t_first, t_last = t.min(), t.max()
    series = Legendre(p, domain=[t_first, t_last])
    return series(t)
コード例 #11
0
ファイル: basis.py プロジェクト: marchdf/dg1d
    def evaluate_basis_gauss(self):
        """Evaluate the basis at the Gaussian quadrature nodes.

        phi will be used to transform Legendre solution coefficients
        to the solution evaluated at the Gaussian quadrature nodes.

        dphi_w will be used for the interior flux integral.

        Returns ``(phi, dphi_w)`` where ``phi[q, n]`` is basis function n at
        node ``self.x[q]`` and ``dphi_w[n, q]`` is the derivative of basis
        function n at node q multiplied by the weight ``self.w[q]``.
        """
        n_nodes = len(self.x)
        phi = np.zeros((n_nodes, self.N_s))
        # dphi_w is filled one row per basis function, so it must be
        # (N_s, n_nodes). The previous (n_nodes, N_s) allocation only worked
        # when the number of nodes happened to equal N_s.
        dphi_w = np.zeros((self.N_s, n_nodes))

        for n in range(self.N_s):

            # Get the Legendre polynomial of order n and its gradient
            mode = L.basis(n)
            d_mode = mode.deriv()

            # Evaluate the basis at the Gaussian nodes
            phi[:, n] = leg.legval(self.x, mode.coef)

            # Evaluate the gradient at the Gaussian nodes and multiply by the
            # weights
            dphi_w[n, :] = leg.legval(self.x, d_mode.coef) * self.w

        return phi, dphi_w
コード例 #12
0
ファイル: basis.py プロジェクト: wo6677/dg1d
    def evaluate_basis_gauss(self):
        """Evaluate the basis at the Gaussian quadrature nodes.

        phi will be used to transform Legendre solution coefficients
        to the solution evaluated at the Gaussian quadrature nodes.

        dphi_w will be used for the interior flux integral.

        Returns ``(phi, dphi_w)`` where ``phi[q, n]`` is basis function n at
        node ``self.x[q]`` and ``dphi_w[n, q]`` is the derivative of basis
        function n at node q multiplied by the weight ``self.w[q]``.
        """
        n_nodes = len(self.x)
        phi = np.zeros((n_nodes, self.N_s))
        # dphi_w is filled one row per basis function, so it must be
        # (N_s, n_nodes). The previous (n_nodes, N_s) allocation only worked
        # when the number of nodes happened to equal N_s.
        dphi_w = np.zeros((self.N_s, n_nodes))

        for n in range(self.N_s):

            # Get the Legendre polynomial of order n and its gradient
            mode = L.basis(n)
            d_mode = mode.deriv()

            # Evaluate the basis at the Gaussian nodes
            phi[:, n] = leg.legval(self.x, mode.coef)

            # Evaluate the gradient at the Gaussian nodes and multiply by the
            # weights
            dphi_w[n, :] = leg.legval(self.x, d_mode.coef) * self.w

        return phi, dphi_w
コード例 #13
0
def regress_poly(degree, data, remove_mean=True, axis=-1):
    ''' Return data with a Legendre polynomial of the given degree regressed out.

    By default the regression runs along the last axis (usually time).
    If remove_mean is True (default), the data is demeaned (i.e. degree 0).
    If remove_mean is False, the degree-0 term is left in the data.

    The result has the same shape as the input.
    '''
    IFLOG.debug('Performing polynomial regression on data of shape ' +
                str(data.shape))

    # Bring the regression axis to the end before flattening. A bare
    # reshape((-1, timepoints)) only lines the series up in rows when that
    # axis is already the last one.
    moved = np.moveaxis(data, axis, -1)
    timepoints = moved.shape[-1]

    # Rearrange all voxel-wise time-series in rows
    flat = moved.reshape((-1, timepoints))

    # Design matrix: constant column first, then Legendre degrees 1..degree
    value_array = np.linspace(-1, 1, timepoints)
    columns = [np.ones(timepoints)]
    columns.extend(Legendre.basis(order)(value_array)
                   for order in range(1, degree + 1))
    X = np.column_stack(columns)

    # Calculate coefficients
    betas = np.linalg.pinv(X).dot(flat.T)

    # Estimation
    if remove_mean:
        datahat = X.dot(betas).T
    else:  # disregard the first layer of X, which is degree 0
        datahat = X[:, 1:].dot(betas[1:, ...]).T
    regressed_data = flat - datahat

    # Back to original shape and axis order
    return np.moveaxis(regressed_data.reshape(moved.shape), -1, axis)
コード例 #14
0
ファイル: confounds.py プロジェクト: NifTK/nipype
def regress_poly(degree, data, remove_mean=True, axis=-1):
    ''' Return data with a Legendre polynomial of the given degree regressed out.

    By default the regression runs along the last axis (usually time).
    If remove_mean is True (default), the data is demeaned (i.e. degree 0).
    If remove_mean is False, the degree-0 term is left in the data.

    The result has the same shape as the input.
    '''
    IFLOG.debug('Performing polynomial regression on data of shape ' + str(data.shape))

    # Bring the regression axis to the end before flattening. A bare
    # reshape((-1, timepoints)) only lines the series up in rows when that
    # axis is already the last one.
    moved = np.moveaxis(data, axis, -1)
    timepoints = moved.shape[-1]

    # Rearrange all voxel-wise time-series in rows
    flat = moved.reshape((-1, timepoints))

    # Design matrix: constant column first, then Legendre degrees 1..degree
    value_array = np.linspace(-1, 1, timepoints)
    columns = [np.ones(timepoints)]
    columns.extend(Legendre.basis(order)(value_array)
                   for order in range(1, degree + 1))
    X = np.column_stack(columns)

    # Calculate coefficients
    betas = np.linalg.pinv(X).dot(flat.T)

    # Estimation
    if remove_mean:
        datahat = X.dot(betas).T
    else:  # disregard the first layer of X, which is degree 0
        datahat = X[:, 1:].dot(betas[1:, ...]).T
    regressed_data = flat - datahat

    # Back to original shape and axis order
    return np.moveaxis(regressed_data.reshape(moved.shape), -1, axis)
コード例 #15
0
ファイル: enhance.py プロジェクト: wo6677/dg1d
def enhancement_matrices(solution_order, modes):
    r"""Returns the enhancement matrices (and their inverse)

    Returns A and inv(A) where A \hat{u} = [uL;some_modes_of(uR)]
            B and inv(B) where B \hat{u} = [uR;some_modes_of(uL)]

    Note: this is slightly different than what I do in
    icb_functions.py (called by advection.py) where the right hand
    side contains the normalization factors (i.e A x = b where b =
    uL_i \int \phi_i \phi_i dx). Here I put \int \phi_i \phi_i dx into
    A and B (denoted norm down in the code below).

    :param solution_order: polynomial order of the cell solution.
    :param modes: orders of the neighbor-cell modes kept in the enhancement.
    :returns: tuple ``(A, inv(A), B, inv(B))`` of numpy arrays.
    """
    # (The docstring is a raw string so the LaTeX backslashes are not
    # parsed as invalid escape sequences.)

    # Enhanced solution order
    n_extra = len(modes)
    order = solution_order + n_extra

    # Submatrices to build the main matrix later
    a = np.eye(solution_order + 1)
    b = np.zeros((solution_order + 1, n_extra))
    cl = np.zeros((n_extra, order + 1))
    cr = np.zeros((n_extra, order + 1))

    # Enhanced basis functions extending into the right cell (shift +2) or
    # the left cell (shift -2). They do not depend on the mode, so build
    # them once rather than once per (mode, j) pair.
    shifted_left = [basis.shift_legendre_polynomial(L.basis(j), 2)
                    for j in range(order + 1)]
    shifted_right = [basis.shift_legendre_polynomial(L.basis(j), -2)
                     for j in range(order + 1)]

    # Loop on the modes we are keeping in the neighboring cell
    for i, mode in enumerate(modes):

        # Basis function in the neighboring cell and its normalization
        l1 = L.basis(mode)
        norm = basis.integrate_legendre_product(l1, l1)

        # Inner products for the left and right enhancements
        for j in range(order + 1):
            cl[i, j] = basis.integrate_legendre_product(l1, shifted_left[j]) / norm
            cr[i, j] = basis.integrate_legendre_product(l1, shifted_right[j]) / norm

    # Put the matrices together
    top = np.hstack((a, b))
    A = np.vstack((top, cl))
    B = np.vstack((top, cr))
    return A, np.linalg.inv(A), B, np.linalg.inv(B)
コード例 #16
0
ファイル: enhance.py プロジェクト: marchdf/dg1d
def enhancement_matrices(solution_order, modes):
    r"""Returns the enhancement matrices (and their inverse)

    Returns A and inv(A) where A \hat{u} = [uL;some_modes_of(uR)]
            B and inv(B) where B \hat{u} = [uR;some_modes_of(uL)]

    Note: this is slightly different than what I do in
    icb_functions.py (called by advection.py) where the right hand
    side contains the normalization factors (i.e A x = b where b =
    uL_i \int \phi_i \phi_i dx). Here I put \int \phi_i \phi_i dx into
    A and B (denoted norm down in the code below).

    :param solution_order: polynomial order of the cell solution.
    :param modes: orders of the neighbor-cell modes kept in the enhancement.
    :returns: tuple ``(A, inv(A), B, inv(B))`` of numpy arrays.
    """
    # (The docstring is a raw string so the LaTeX backslashes are not
    # parsed as invalid escape sequences.)

    # Enhanced solution order
    n_extra = len(modes)
    order = solution_order + n_extra

    # Submatrices to build the main matrix later
    a = np.eye(solution_order + 1)
    b = np.zeros((solution_order + 1, n_extra))
    cl = np.zeros((n_extra, order + 1))
    cr = np.zeros((n_extra, order + 1))

    # Enhanced basis functions extending into the right cell (shift +2) or
    # the left cell (shift -2). They do not depend on the mode, so build
    # them once rather than once per (mode, j) pair.
    shifted_left = [basis.shift_legendre_polynomial(L.basis(j), 2)
                    for j in range(order + 1)]
    shifted_right = [basis.shift_legendre_polynomial(L.basis(j), -2)
                     for j in range(order + 1)]

    # Loop on the modes we are keeping in the neighboring cell
    for i, mode in enumerate(modes):

        # Basis function in the neighboring cell and its normalization
        l1 = L.basis(mode)
        norm = basis.integrate_legendre_product(l1, l1)

        # Inner products for the left and right enhancements
        for j in range(order + 1):
            cl[i, j] = basis.integrate_legendre_product(l1, shifted_left[j]) / norm
            cr[i, j] = basis.integrate_legendre_product(l1, shifted_right[j]) / norm

    # Put the matrices together
    top = np.hstack((a, b))
    A = np.vstack((top, cl))
    B = np.vstack((top, cr))
    return A, np.linalg.inv(A), B, np.linalg.inv(B)
コード例 #17
0
    def create(self, remake=False):
        '''
        Populate the wavelength calibration for this aperture,
        using input from the user to try to link up lines between
        the measured arc spectra and the known line wavelengths.

        Repeatedly fits a Legendre polynomial (pixel -> wavelength), flags
        outliers among ``self.matches``, refits and plots, for as long as
        ``self.notconverged`` is true.

        NOTE(review): ``remake`` is not used anywhere in this method.
        '''

        self.speak("populating wavelength calibration")

        # loop through until we've decided we're converged
        # NOTE(review): nothing in this method sets self.notconverged to
        # False; presumably plotWavelengthFit (or a callback it installs)
        # does -- confirm, otherwise this loop never ends.
        self.notconverged = True
        while (self.notconverged):

            # set an initial guess matching wavelengths to my pixels
            #self.guessMatches()

            # do a fit, unless a calibration was just loaded (in which case
            # the flag is cleared and the existing polynomial is kept)
            if self.justloaded:
                self.justloaded = False
            else:
                # do an initial fit

                self.pixelstowavelengths = Legendre.fit(
                    x=self.pixel,
                    y=self.wavelength,
                    deg=self.polynomialdegree,
                    w=self.weights)
            # identify outliers: the threshold is 4 * 1.48 * MAD of the
            # residuals selected by self.good, but never less than 1.0
            limit = 1.48 * craftroom.oned.mad(self.residuals[self.good]) * 4
            limit = np.maximum(limit, 1.0)
            # outliers get reset each time (so don't lose at edges)
            outlier = np.abs(self.residuals) > limit
            # keep track of which are outlier
            for i, m in enumerate(self.matches):
                self.matches[i]['mine']['outlier'] = outlier[i]
            # NOTE(review): `i` here is the last loop index, not the number
            # of outliers, and is undefined if self.matches is empty.
            self.speak('points beyond {0} are outliers ({1})'.format(limit, i))

            # refit after the outlier flags were updated
            # (presumably self.weights depends on them -- verify)
            self.pixelstowavelengths = Legendre.fit(x=self.pixel,
                                                    y=self.wavelength,
                                                    deg=self.polynomialdegree,
                                                    w=self.weights)

            self.plotWavelengthFit()
コード例 #18
0
def dg_interface_flux_matrix(p, T):
    """Build the interface flux matrices for standard DG (upwind fluxes).

    T denotes the translation necessary to evaluate the flux from the
    other cell.

    ``G0[i, j]`` is L_i(-1) * L_j(1) * T**-1 and ``G1[i, j]`` is
    L_i(1) * L_j(1), for Legendre polynomials of order 0..p.

    There are also analytical expressions for these (see the author's
    thesis, page 73).
    """
    n_modes = p + 1

    # Endpoint values of each Legendre mode, evaluated once.
    at_left = [leg.legval(-1, L.basis(k).coef) for k in range(n_modes)]
    at_right = [leg.legval(1, L.basis(k).coef) for k in range(n_modes)]

    G0 = smp.zeros(n_modes)
    G1 = smp.zeros(n_modes)
    for i in range(n_modes):
        for j in range(n_modes):
            G0[i, j] = at_left[i] * at_right[j] * (T**-1)
            G1[i, j] = at_right[i] * at_right[j]

    return G0, G1
コード例 #19
0
ファイル: dg_stability.py プロジェクト: marchdf/dg1d
def dg_interface_flux_matrix(p, T):
    """Build the interface flux matrices for standard DG (upwind fluxes).

    T denotes the translation necessary to evaluate the flux from the
    other cell.

    ``G0[i, j]`` is L_i(-1) * L_j(1) * T**-1 and ``G1[i, j]`` is
    L_i(1) * L_j(1), for Legendre polynomials of order 0..p.

    There are also analytical expressions for these (see the author's
    thesis, page 73).
    """
    size = p + 1

    # Endpoint values of each Legendre mode, evaluated once.
    left_vals = [leg.legval(-1, L.basis(order).coef) for order in range(size)]
    right_vals = [leg.legval(1, L.basis(order).coef) for order in range(size)]

    G0 = smp.zeros(size)
    G1 = smp.zeros(size)
    for row in range(size):
        for col in range(size):
            G0[row, col] = left_vals[row] * right_vals[col] * (T**-1)
            G1[row, col] = right_vals[row] * right_vals[col]

    return G0, G1
コード例 #20
0
ファイル: basis.py プロジェクト: marchdf/dg1d
def shift_legendre_polynomial(l, shift):
    """Return the Legendre series representing ``l(x + shift)``.

    The input series is re-expressed on a window moved by ``-shift``; the
    resulting coefficients are then wrapped in a new series that uses the
    default domain and window.
    """
    moved_window = l.window - shift
    recast = L.cast(l, window=moved_window)

    # Only the coefficients are kept, so the result has the default window.
    return L(recast.coef)
コード例 #21
0
ファイル: test_basis.py プロジェクト: marchdf/dg1d
    def test_shift_legendre_polynomial(self):
        """Shifting P_2 by 2 must give P_2(x + 2)."""

        # P_2(x) = 0.5 * (3 x^2 - 1)
        second_order = L.basis(2)
        shifted = basis.shift_legendre_polynomial(second_order, 2)

        # In the power basis, P_2(x + 2) = 5.5 + 6 x + 1.5 x^2
        expected_coef = np.array([5.5, 6, 1.5])
        power_coef = shifted.convert(kind=P).coef
        npt.assert_array_almost_equal(power_coef, expected_coef, decimal=13)
コード例 #22
0
ファイル: test_basis.py プロジェクト: wo6677/dg1d
    def test_integrate_legendre_product(self):
        """The integral of P_2(x) * P_2(x + 2) over [-1, 1] must be 0.4."""

        # P_2(x) = 0.5 * (3 x^2 - 1) and its copy shifted to P_2(x + 2)
        second_order = L.basis(2)
        shifted = basis.shift_legendre_polynomial(second_order, 2)

        integral = basis.integrate_legendre_product(second_order, shifted)
        self.assertAlmostEqual(integral, 0.4)
コード例 #23
0
ファイル: basis.py プロジェクト: wo6677/dg1d
def shift_legendre_polynomial(l, shift):
    """Return the Legendre series representing ``l(x + shift)``.

    The input series is re-expressed on a window moved by ``-shift``; the
    resulting coefficients are then wrapped in a new series that uses the
    default domain and window.
    """
    target_window = l.window - shift
    shifted_series = L.cast(l, window=target_window)

    # Only the coefficients are kept, so the result has the default window.
    return L(shifted_series.coef)
コード例 #24
0
ファイル: test_basis.py プロジェクト: marchdf/dg1d
    def test_integrate_legendre_product(self):
        """The integral of P_2(x) * P_2(x + 2) over [-1, 1] must be 0.4."""

        # P_2(x) = 0.5 * (3 x^2 - 1) and its copy shifted to P_2(x + 2)
        base_poly = L.basis(2)
        moved_poly = basis.shift_legendre_polynomial(base_poly, 2)

        product_integral = basis.integrate_legendre_product(base_poly,
                                                            moved_poly)
        self.assertAlmostEqual(product_integral, 0.4)
コード例 #25
0
def nodal_flux_builder_edge(coefficient, diff1, diff2):
    """Sample the scaled derivatives of two Legendre series.

    The first five coefficients form one series and the remaining ones a
    second series. Each series is differentiated once, then sampled at 65
    evenly spaced points on [0, 10] and [10, 20] respectively, and the
    samples are scaled by ``diff1`` and ``diff2``.

    Returns ``(total_flux, flux_position)``, both concatenated over the two
    intervals.
    """
    left_slope = Legendre(coefficient[:5]).deriv(1)
    right_slope = Legendre(coefficient[5:]).deriv(1)

    left_x, left_y = left_slope.linspace(65, [0, 10])
    right_x, right_y = right_slope.linspace(65, [10, 20])

    total_flux = np.concatenate((diff1 * left_y, diff2 * right_y))
    flux_position = np.concatenate((left_x, right_x))
    return total_flux, flux_position
コード例 #26
0
    def genNoise(self,grid,maxNoiseOrder):
        """Build the noise matrix as a DataFrame.

        The columns are Legendre polynomials of the orders returned by
        self.noiseOrders(), evaluated at the sample indices, followed by a
        60 Hz sine and a 60 Hz cosine column.

        grid -- Grid to be used for timing information
        maxNoiseOrder -- Maximum order of noise to be considered
                         (not used in this method)

        Returns None when self.grid() or self.noiseOrders() is None.
        """
        if self.grid() is None or self.noiseOrders() is None:
            return None

        logger.info( 'Generating noise matrix')

        # Sample indices 0..N-1, shared by the polynomial and 60 Hz columns
        samples = np.arange(len(grid.times()))
        polys = np.array([Legendre.basis(order)(samples)
                          for order in self.noiseOrders()]).T

        phase = 60 * samples * 2 * np.pi / float(grid.fs())
        sine_col = np.sin(phase)    # Sine for AC component
        cosine_col = np.cos(phase)  # Cosine for AC component

        noise = np.column_stack((polys, sine_col, cosine_col))
        return pd.DataFrame(noise,index=self.grid().times())
コード例 #27
0
ファイル: test_basis.py プロジェクト: wo6677/dg1d
    def test_shift_legendre_polynomial(self):
        """Shifting P_2 by 2 must give P_2(x + 2)."""

        # P_2(x) = 0.5 * (3 x^2 - 1)
        base_poly = L.basis(2)
        moved_poly = basis.shift_legendre_polynomial(base_poly, 2)

        # In the power basis, P_2(x + 2) = 5.5 + 6 x + 1.5 x^2
        wanted = np.array([5.5, 6, 1.5])
        got = moved_poly.convert(kind=P).coef
        npt.assert_array_almost_equal(got, wanted, decimal=13)
コード例 #28
0
ファイル: tval.py プロジェクト: timothydmorton/fpp-old
def LDT(epoch, tdur, t, f, pad=0.2, deg=1):
    """
    Local detrending.

    Fit a Legendre polynomial of degree ``deg`` to the light curve while
    excluding the points within ``tdur/2 + pad`` of ``epoch``, and return
    that polynomial evaluated at every time in ``t``. (The trend is
    returned; it is not subtracted from ``f`` here.)

    pad : Extra number of days to notch out of the transit region.
    """
    half_width = tdur/2 + pad
    out_of_transit = abs(t - epoch) > half_width

    fitted = Legendre.fit(t[out_of_transit],
                          f[out_of_transit],
                          deg,
                          domain=[t.min(), t.max()])
    return fitted(t)
コード例 #29
0
def nodal_flux_builder_cell(coefficient):
    """Sample two Legendre series built from one coefficient vector.

    The first five coefficients form one series and the remaining ones a
    second series. They are sampled at 64 evenly spaced points on [0, 10]
    and [10, 20] respectively.

    Returns ``(total_flux, flux_position)``, both concatenated over the two
    intervals.
    """
    left_series = Legendre(coefficient[:5])
    right_series = Legendre(coefficient[5:])

    left_x, left_y = left_series.linspace(64, [0, 10])
    right_x, right_y = right_series.linspace(64, [10, 20])

    total_flux = np.concatenate((left_y, right_y))
    flux_position = np.concatenate((left_x, right_x))
    return total_flux, flux_position
コード例 #30
0
    def genNoise(self, grid, maxNoiseOrder):
        """Build the noise matrix as a DataFrame.

        The columns are Legendre polynomials of the orders returned by
        self.noiseOrders(), evaluated at the sample indices, followed by a
        60 Hz sine and a 60 Hz cosine column.

        grid -- Grid to be used for timing information
        maxNoiseOrder -- Maximum order of noise to be considered
                         (not used in this method)

        Returns None when self.grid() or self.noiseOrders() is None.
        """
        if self.grid() is None or self.noiseOrders() is None:
            return None

        logger.info('Generating noise matrix')

        # Sample indices 0..N-1, shared by the polynomial and 60 Hz columns
        sample_index = np.arange(len(grid.times()))
        legendre_cols = np.array([
            Legendre.basis(order)(sample_index)
            for order in self.noiseOrders()
        ]).T

        angle = 60 * sample_index * 2 * np.pi / float(grid.fs())
        sw = np.sin(angle)  # Sine for AC component
        cw = np.cos(angle)  # Cosine for AC component

        noise = np.column_stack((legendre_cols, sw, cw))
        return pd.DataFrame(noise, index=self.grid().times())
コード例 #31
0
def regress_poly(degree, data, remove_mean=True, axis=-1):
    """
    Returns data with degree polynomial regressed out.

    The regression is performed along axis 0 of ``data`` (timepoints in
    rows) using Legendre polynomials sampled on [-1, 1].

    :param int degree: highest Legendre degree to regress out,
    :param data: array with timepoints along axis 0,
    :param bool remove_mean: whether or not demean data (i.e. degree 0),
    :param int axis: accepted for interface compatibility; not used here,
    :returns: tuple ``(regressed_data, non_constant_regressors)`` where the
        second item holds the design-matrix columns of degree 1..degree
        (an empty array when degree is 0).
    """
    timepoints = data.shape[0]

    # Design matrix: constant column first, then Legendre degrees 1..degree
    value_array = np.linspace(-1, 1, timepoints)
    columns = [np.ones(timepoints)]
    columns.extend(Legendre.basis(order)(value_array)
                   for order in range(1, degree + 1))
    X = np.column_stack(columns)

    # The constant is column 0, so the non-constant regressors are all the
    # other columns. (Slicing [:, :-1] kept the constant and dropped the
    # highest degree instead.)
    non_constant_regressors = X[:, 1:] if X.shape[1] > 1 else np.array([])

    betas = np.linalg.pinv(X).dot(data)
    if remove_mean:
        datahat = X.dot(betas)
    else:  # disregard the first layer of X, which is degree 0
        datahat = X[:, 1:].dot(betas[1:, ...])
    regressed_data = data - datahat
    return regressed_data, non_constant_regressors
コード例 #32
0
def approx_legendre_poly(Moments):
    """Build a polynomial expansion from a sequence of moments.

    For every order ``i`` the Legendre polynomial is converted to the power
    basis with window [0, 1]; its weight is ``(2 i + 1)`` times the dot
    product of its coefficients with the first ``i + 1`` moments, and the
    weighted polynomials are summed.

    For the method description see, for instance, chapter 3 of
    "The Problem of Moments", James Alexander Shohat, Jacob David Tamarkin.

    Returns the expansion as a ``Polynomial``.
    """
    n_terms = Moments.shape[0]

    exp_coef = np.zeros(1)
    for order in range(n_terms):
        shifted = Legendre.basis(order).convert(window=[0.0, 1.0],
                                                kind=Polynomial)
        weight = (2*order + 1) * np.sum(Moments[0:(order + 1)] * shifted.coef)
        exp_coef = polynomial.polyadd(exp_coef, shifted.coef * weight)

    return Polynomial(exp_coef)
コード例 #33
0
ファイル: confounds.py プロジェクト: kesshijordan/nipype
def regress_poly(degree, data, remove_mean=True, axis=-1):
    """
    Returns data with degree polynomial regressed out.

    Legendre polynomials sampled on [-1, 1] are fitted along ``axis`` and
    subtracted; the result has the same shape as the input.

    :param int degree: highest Legendre degree to regress out,
    :param data: array holding the time-series along ``axis``,
    :param bool remove_mean: whether or not demean data (i.e. degree 0),
    :param int axis: numpy array axes along which regression is performed,
    :returns: tuple ``(regressed_data, non_constant_regressors)`` where the
        second item holds the design-matrix columns of degree 1..degree
        (an empty array when degree is 0).

    """
    IFLOGGER.debug('Performing polynomial regression on data of shape %s',
                   str(data.shape))

    # Bring the regression axis to the end before flattening. A bare
    # reshape((-1, timepoints)) only lines the series up in rows when that
    # axis is already the last one.
    moved = np.moveaxis(data, axis, -1)
    timepoints = moved.shape[-1]

    # Rearrange all voxel-wise time-series in rows
    flat = moved.reshape((-1, timepoints))

    # Design matrix: constant column first, then Legendre degrees 1..degree
    value_array = np.linspace(-1, 1, timepoints)
    columns = [np.ones(timepoints)]
    columns.extend(Legendre.basis(order)(value_array)
                   for order in range(1, degree + 1))
    X = np.column_stack(columns)

    # The constant is column 0, so the non-constant regressors are all the
    # other columns. (Slicing [:, :-1] kept the constant and dropped the
    # highest degree instead.)
    non_constant_regressors = X[:, 1:] if X.shape[1] > 1 else np.array([])

    # Calculate coefficients
    betas = np.linalg.pinv(X).dot(flat.T)

    # Estimation
    if remove_mean:
        datahat = X.dot(betas).T
    else:  # disregard the first layer of X, which is degree 0
        datahat = X[:, 1:].dot(betas[1:, ...]).T
    regressed_data = flat - datahat

    # Back to original shape and axis order
    restored = np.moveaxis(regressed_data.reshape(moved.shape), -1, axis)
    return restored, non_constant_regressors
コード例 #34
0
def regress_poly(degree, data, remove_mean=True, axis=-1):
    """
    Returns data with degree polynomial regressed out.

    Legendre polynomials sampled on [-1, 1] are fitted along ``axis`` and
    subtracted; the result has the same shape as the input.

    :param int degree: highest Legendre degree to regress out,
    :param data: array holding the time-series along ``axis``,
    :param bool remove_mean: whether or not demean data (i.e. degree 0),
    :param int axis: numpy array axes along which regression is performed,
    :returns: tuple ``(regressed_data, non_constant_regressors)`` where the
        second item holds the design-matrix columns of degree 1..degree
        (an empty array when degree is 0).

    """
    IFLOGGER.debug('Performing polynomial regression on data of shape %s',
                   str(data.shape))

    # Bring the regression axis to the end before flattening. A bare
    # reshape((-1, timepoints)) only lines the series up in rows when that
    # axis is already the last one.
    moved = np.moveaxis(data, axis, -1)
    timepoints = moved.shape[-1]

    # Rearrange all voxel-wise time-series in rows
    flat = moved.reshape((-1, timepoints))

    # Design matrix: constant column first, then Legendre degrees 1..degree
    value_array = np.linspace(-1, 1, timepoints)
    columns = [np.ones(timepoints)]
    columns.extend(Legendre.basis(order)(value_array)
                   for order in range(1, degree + 1))
    X = np.column_stack(columns)

    # The constant is column 0, so the non-constant regressors are all the
    # other columns. (Slicing [:, :-1] kept the constant and dropped the
    # highest degree instead.)
    non_constant_regressors = X[:, 1:] if X.shape[1] > 1 else np.array([])

    # Calculate coefficients
    betas = np.linalg.pinv(X).dot(flat.T)

    # Estimation
    if remove_mean:
        datahat = X.dot(betas).T
    else:  # disregard the first layer of X, which is degree 0
        datahat = X[:, 1:].dot(betas[1:, ...]).T
    regressed_data = flat - datahat

    # Back to original shape and axis order
    restored = np.moveaxis(regressed_data.reshape(moved.shape), -1, axis)
    return restored, non_constant_regressors
コード例 #35
0
ファイル: tracepol.py プロジェクト: AntoineDarveau/jwst-mtl
def specpix_to_wavelength(specpix, tracepars, m=1, frame='dms', oversample=1):
    """Map spectral pixel coordinates onto wavelengths for spectral order m.

    The stored polynomial is evaluated on the nat-frame pixel coordinate and
    its value is exponentiated to give the wavelength.

    :param specpix: the pixel values.
    :param tracepars: the trace polynomial solutions returned by get_tracepars.
    :param m: the spectral order.
    :param frame: the coordinate frame of the input coordinates (nat, dms or sim).
    :param oversample: the oversampling factor of the input coordinates.

    :type specpix: array[float]
    :type tracepars: dict
    :type m: int
    :type frame: str
    :type oversample: int

    :returns: wavelength - an array containing the wavelengths corresponding to specpix,
    mask - an array that is True when the specpix values were within the valid range of the polynomial.
    :rtype: Tuple(array[float], array[bool])
    """

    # Undo the oversampling, then move the coordinates into the nat frame.
    native_pix = specpix_frame_to_ref(specpix / oversample, frame=frame)

    # Rebuild the pixel -> log(wavelength) polynomial for this order.
    order_pars = tracepars[m]
    to_logwave = Legendre(order_pars['wave_coef'],
                          domain=order_pars['wave_domain'])

    # Far outside the domain the exponential may overflow; silence the warning.
    with np.errstate(over='ignore'):
        wavelength = np.exp(to_logwave(native_pix))

    # Flag which inputs fall inside the polynomial's domain.
    low = order_pars['wave_domain'][0]
    high = order_pars['wave_domain'][1]
    mask = bounds_check(native_pix, low, high)

    return wavelength, mask
コード例 #36
0
# Fragment of a stand-alone wavelength-recalibration script (Python 2: it
# uses raw_input). It relies on names defined outside this excerpt:
# starmasterstr, coef, domain, r, np, sys and Legendre.
#starmaster = np.load('/h/mulan0/data/working/GJ1132b_ut160421_22/multipleapertures/aperture_587_1049/extracted'+mastern+'.npy')[()]
#oef, domain =  np.load('/h/mulan0/data/working/GJ1132b_ut160421_22/multipleapertures/aperture_587_1049/aperture_587_1049_wavelengthcalibration.npy')[()]

# this is a check to make sure you have commented in the correct star (this is obviously a hack that could be done better...)
# NOTE(review): any answer other than 'y' or 'n' falls through without
# assigning starmaster in this fragment.
response = raw_input(
    'Is this the correct starmaster for this dataset? [y/n] \n     ' +
    starmasterstr + '\n')
if response == 'n':
    sys.exit(
        'Please edit wavelength_recalibration.py to point to the correct starmaster'
    )
elif response == 'y':
    starmaster = np.load(starmasterstr)[()]

# recreate the wavelength solution (way to go from pixel space to wavelength space) from mosasaurus
# NOTE(review): coef and domain are not assigned anywhere in the visible
# code; the line above that would load them is commented out.
pxtowavemaster = Legendre(coef, domain)
# apertures from a given night (Need to have run mosasaurus in ipython or similar and then copy and paste this monster code in. This is not ideal.)
apertures = r.mask.apertures
makeplot = True
# one list per alignment feature, plus x_poses; all start empty
UV_poses, O2_poses, Ca1_poses, Ca2_poses, Ca3_poses, H2O_poses = [], [], [], [], [], []
x_poses = []
#shifts_1132 = []
# Fixed alignment ranges for each prominent feature
# (presumably wavelengths in angstroms -- TODO confirm)
align_UV = (6870, 6900)
align_O2 = (7580, 7650)
#align_H2Osmall = (8220, 8260)
align_Ca1 = (8490, 8525)
align_Ca2 = (8535, 8580)
align_Ca3 = (8650, 8700)
align_H2O = (9300, 9700)
for n in r.obs.nScience:
コード例 #37
0
ファイル: tracepol.py プロジェクト: AntoineDarveau/jwst-mtl
def trace_polynomial(trace, m=1, maxorder=15):
    """Fit a polynomial to the trace of order m and return a
    dictionary containing the parameters and validity intervals.

    Legendre polynomials of increasing degree are fitted to the spatial and
    spectral pixel positions as functions of log(wavelength) until both fits
    reproduce every point to better than 0.5 pixel, or until ``maxorder`` is
    reached. The inverse relation (spectral pixel -> log(wavelength)) is then
    fitted with the same degree.

    :param trace: astropy table containing modelled trace points.
    :param m: spectral order for which to fit a polynomial.
    :param maxorder: maximum polynomial order to use.

    :type trace: astropy.table.Table
    :type m: int
    :type maxorder: int

    :returns: pars - dictionary containing the polynomial solution
        (coefficients and domains under the keys ``spat_*``, ``spec_*``
        and ``wave_*``).
    :rtype: dict
    :raises ValueError: if maxorder is negative.
    """

    # TODO added arbitrary maxorder to deal with poor extrapolation, revisit when extrapolation fixed.

    if maxorder < 0:
        raise ValueError('maxorder must be non-negative')

    # Select the data for order m.
    mask = (trace['order'] == m)
    wave = trace['Wavelength'][mask]
    spatpix_ref = trace['xpos'][mask]
    specpix_ref = trace['ypos'][mask]

    # Find the edges of the wavelength range.
    wavemin = np.amin(wave)
    wavemax = np.amax(wave)

    # The fits are done against log(wavelength); compute it once.
    logwave = np.log(wave)

    # Compute the polynomial parameters for x and y. Iterating over a range
    # leaves `order` equal to the degree actually fitted, even when the
    # tolerance is never met (it never exceeds maxorder).
    for order in range(maxorder + 1):

        spatpol = Legendre.fit(logwave, spatpix_ref, order)
        specpol = Legendre.fit(logwave, specpix_ref, order)

        spat_ok = np.all(np.abs(spatpix_ref - spatpol(logwave)) < 0.5)
        spec_ok = np.all(np.abs(specpix_ref - specpol(logwave)) < 0.5)

        if spat_ok and spec_ok:
            break

    # Compute the transform back to wavelength.
    wavegrid = wavemin + (wavemax - wavemin) * np.linspace(0., 1., 501)
    specgrid = specpol(np.log(wavegrid))
    wavepol = Legendre.fit(specgrid, np.log(wavegrid), order)

    # Add the parameters to a dictionary.
    pars = dict()
    pars['spat_coef'] = spatpol.coef
    pars['spat_domain'] = spatpol.domain
    pars['spec_coef'] = specpol.coef
    pars['spec_domain'] = specpol.domain
    pars['wave_coef'] = wavepol.coef
    pars['wave_domain'] = wavepol.domain

    return pars
コード例 #38
0
class WavelengthCalibrator(Talker):
    def __init__(self,  aperture,
                        elements=None,
                        polynomialdegree=3,
                        matchdistance=100):
        '''Set up the calibrator for one aperture and populate it.

        aperture: the aperture object this calibration belongs to
        elements: names of the arc lamps to use; None (the default)
            means ['He', 'Ne', 'Ar']
        polynomialdegree: degree of the pixel-to-wavelength polynomial
        matchdistance: threshold (in pixels) used when pairing detected
            peaks with reference lines

        The constructor ends by calling populate(), which either loads a
        previous calibration or creates a new one.'''
        Talker.__init__(self)

        # keep track of the aperture this belongs to
        self.aperture = aperture

        # set up some defaults; build the default list per call so that
        # instances never share (and cannot mutate) one default list object
        self.elements = ['He', 'Ne', 'Ar'] if elements is None else elements
        self.polynomialdegree = polynomialdegree
        self.matchdistance = matchdistance

        # either load a previous wavecal, or create a new one
        self.populate()

    @property
    def wavelengthprefix(self):
        return self.aperture.directory + '{0}_'.format(self.aperture.name)

    @property
    def calibrationfilename(self):
        return self.wavelengthprefix + 'wavelengthcalibration.npy'

    @property
    def waveidfilename(self):
        return self.wavelengthprefix + 'waveids.txt'

    def loadWavelengthIdentifications(self, restart=False):
        ''' Try loading a custom stored wavelength ID file
            from the aperture's directory, and if that
            doesn't work, load the default one for this grism.

            restart: if True, skip the custom file and start again
                from the default identifications.

            Afterwards the known wavelengths are collected and an
            initial set of peak matches is guessed.'''

        # decide with a plain flag rather than an assert, because asserts
        # are stripped when Python runs with -O
        usedefault = bool(restart)

        if not usedefault:
            try:
                # try to load a custom wavelength id file
                self.speak('checking for a custom wavelength-to-pixel file')
                self.waveids = astropy.io.ascii.read(self.waveidfilename)

                # keep track of which waveid file is used
                self.whichwaveid = 'Aperture-Specific ({0})'.format(
                    self.waveidfilename.split('/')[-1])
            except (IOError, AssertionError):
                usedefault = True

        if usedefault:
            self.speak('no custom wavelength-to-pixel files found')

            # load the default for this grism, as set in obs. file
            d = astropy.io.ascii.read(self.aperture.obs.wavelength2pixelsFile)
            self.rawwaveids = d[['pixel', 'wavelength', 'name']]
            self.rawwaveids['pixel'] /= self.aperture.obs.binning

            # pull out the peaks from extracted arc spectra
            self.findPeaks()

            # perform a first rough alignment (this call defines waveids)
            self.findRoughShift()

            # keep track of which waveid file is used
            self.whichwaveid = 'Default ({0})'.format(
                self.aperture.obs.wavelength2pixelsFile.split('/')[-1])

        self.speak('loaded {0}:'.format(self.whichwaveid))
        self.findKnownWavelengths()
        self.guessMatches()

    def saveWavelengthIdentification(self):
        '''Write the wavelength-to-pixel identifications to their file
        as a "|"-delimited fixed-width table.'''

        writeoptions = dict(format='ascii.fixed_width',
                            delimiter='|',
                            bookend=False)
        self.waveids.write(self.waveidfilename, **writeoptions)

        message = 'saved wavelength-to-pixel matches to '+self.waveidfilename
        self.speak(message)

    def save(self):
        '''Write every product of the wavelength calibration to disk.'''

        self.speak('saving all aspects of the wavelength calibration')

        # identifications, then the matches actually used (outliers
        # included), then the polynomial coefficients and domain
        self.saveWavelengthIdentification()
        self.saveMatches()
        self.saveCalibration()

        # finally, the calibration figure
        figurepath = self.wavelengthprefix + 'calibration.pdf'
        self.figcal.savefig(figurepath)

    def loadCalibration(self):

        coef, domain = np.load(self.calibrationfilename)
        self.pixelstowavelengths = Legendre(coef, domain)
        self.polynomialdegree = self.pixelstowavelengths.degree()
        self.speak("loaded wavelength calibration"
                "from {0}".format(self.calibrationfilename))

    def saveCalibration(self):
        np.save(self.calibrationfilename, (self.pixelstowavelengths.coef, self.pixelstowavelengths.domain))
        self.speak("saved wavelength calibration coefficients to {0}".format(self.calibrationfilename))

    def populate(self, restart=False):
        '''Load the wavelength identifications, then either load a stored
        calibration (with its matches) or create a new one.

        restart: if True, ignore anything stored and create the
            calibration from scratch.'''

        # populate the wavelength identifications
        self.loadWavelengthIdentifications(restart=restart)

        # an explicit branch instead of an assert: asserts disappear
        # under "python -O", which would silently ignore restart=True
        if restart:
            self.justloaded = False
            self.create()
            return

        try:
            # populate the wavelength calibration polynomial
            self.loadCalibration()
            self.justloaded = True
            self.loadMatches()
            #self.plotWavelengthFit(interactive=False)
            #unhappy = ('n' in self.input('Are you happy with the wavelength calibration? [Y,n]').lower())
            #assert(unhappy == False)
        except (IOError, AssertionError):
            # nothing usable on disk: fall back to making a new calibration
            self.justloaded = False
            self.create()


    def findRoughShift(self, blob=2.0):
        '''Estimate the rough pixel offset between this aperture's arc
        peaks and the reference pixel positions in self.rawwaveids, and
        define self.waveids as the shifted copy of those references.

        blob: width (in pixels) of the Gaussians used to turn the line
            lists into fake spectra for the cross-correlation.

        A diagnostic figure is saved in the aperture's directory.

        NOTE(review): the offset actually applied is hard-coded (see the
        KLUDGE line below); the correlation functions are computed and
        plotted but not used to set it.'''

        self.speak("cross correlating arcs with known wavelengths")

        # create a plot showing how well the lines match
        figure_waverough = plt.figure(  'wavelength rough offset',
                                        figsize=(6,4), dpi=100)
        # NOTE(review): the grid has 3 rows, so more than three elements
        # would presumably not fit -- confirm
        gs = plt.matplotlib.gridspec.GridSpec(  3,2,
                                                bottom=0.15, top=0.85,
                                                hspace=0.1,wspace=0,
                                                width_ratios=[1, .5])

        # make the axes for checking the rough alignment
        # (one row per element; each column shares its x axis down the rows)
        self.ax_waverough, self.ax_wavecor = {}, {}
        sharer, sharec = None, None
        for i, e in enumerate(self.elements):
            self.ax_waverough[e] = plt.subplot(gs[i,0],sharex=sharer)
            self.ax_wavecor[e] = plt.subplot(gs[i,1], sharex=sharec)
            sharer, sharec = self.ax_waverough[e], self.ax_wavecor[e]

        # calculate correlation functions
        self.corre = {}
        for count, element in enumerate(self.elements):
            # pull out the peaks
            xPeak = [p['w'] for p in self.peaks[element]]
            yPeak = [p['intensity'] for p in self.peaks[element]]

            # create fake spectra using the line positions (reference + new)
            x = np.arange(-self.aperture.obs.ysize,self.aperture.obs.ysize)

            myPeaks, theirPeaks = np.zeros(len(x)), np.zeros(len(x))
            # create fake spectrum of their peaks
            # (a unit-height Gaussian at every reference line of this element)
            for i in range(len(self.rawwaveids)):
                if element in self.rawwaveids['name'][i]:
                    center = self.rawwaveids['pixel'][i]
                    theirPeaks += np.exp(-0.5*((x-center)/blob)**2)

            # create fake spectrum of my peaks
            # (Gaussians weighted by the log of the peak intensity)
            for i in range(len(xPeak)):
                center = xPeak[i]
                myPeaks += np.exp(-0.5*((x-center)/blob)**2)*np.log(yPeak[i])

            # calculate the correlation function for this element
            self.corre[element] = np.correlate(myPeaks, theirPeaks, 'full')

            # plot the rough shift and identifications
            self.ax_waverough[element].plot(x, myPeaks/myPeaks.max(),
                                            label='extracted', alpha=0.5,
                                            color=colors[element])
            self.ax_waverough[element].set_ylim(0, 2)


            # plot the correlation functions
            # (normcor is the same array object as self.corre[element], so the
            #  in-place division below normalizes the stored array as well)
            normcor = self.corre[element]
            assert(np.isfinite(self.corre[element]).any())

            normcor /= np.nanmax(self.corre[element])
            self.ax_wavecor[element].plot(normcor, label=element, alpha=0.5,
                                            color=colors[element])
            # tidy up the plots
            for a in [self.ax_wavecor, self.ax_waverough]:
                plt.setp(a[element].get_xticklabels(), visible=False)
                plt.setp(a[element].get_yticklabels(), visible=False)

            # multiply the correlation functions together
            assert(np.isfinite(self.corre[element]).any())
            if count == 0:
                self.corre['combined'] = np.ones_like(self.corre[element])
            self.corre['combined'] *= self.corre[element]

        # find the peak of the combined correlation function
        # NOTE(review): the offset is fixed at -1024; the search for the
        # correlation peak survives only in the trailing comment
        self.peakoffset = -1024 # KLUDGE KLUDGE KLUDGE! np.where(self.corre['combined'] == self.corre['combined'].max())[0][0] - len(x)
        # (old?) to convert: len(x) - xPeak = x + peakoffset

        # define the new, shifted, waveids array
        self.waveids = copy.deepcopy(self.rawwaveids)
        self.waveids['pixel'] += self.peakoffset

        # plot the shifted wavelength ids, and combined corfuncs
        for element in self.elements:
            for i in range(len(self.rawwaveids)):
                if element in self.rawwaveids['name'][i]:
                      center = self.waveids['pixel'][i]
                      self.ax_waverough[element].axvline(center, alpha=0.25, color='black')
            # plot the combined correlation function
            normedcombined = self.corre['combined']/np.max(self.corre['combined'])
            self.ax_wavecor[element].plot(normedcombined,
                                label='combined', alpha=0.25, color='black')
            # tidy up the plots
            # (the labels, titles and savefig below sit inside the element
            #  loop, so they are repeated once per element)
            self.ax_waverough[element].set_ylabel(element)
            ax = self.ax_waverough[self.elements[-1]]
            plt.setp(ax.get_xticklabels(), visible=True)
            ax.set_xlabel('Pixel Position')
            fontsize = 8
            self.ax_wavecor[self.elements[0]].set_title('cross correlation peaks at \n{0} pixels ({1}x{1} binned pixels)'.format(self.peakoffset, self.aperture.obs.binning), fontsize=fontsize)
            self.ax_waverough[self.elements[0]].set_title(
            'Coarse Wavelength Alignment\nfor ({0:0.1f},{1:0.1f})'.format(
                 self.aperture.x, self.aperture.y),fontsize=fontsize)

            # save the figure
            figure_waverough.savefig(
            self.aperture.directory + 'roughWavelengthAlignment_{0}.pdf'.format(
            self.aperture.name))

    @property
    def peaks(self):
        try:
            return self._peaks
        except AttributeError:
            self.findPeaks()
            return self._peaks

    def findPeaks(self):
        '''Extract an arc spectrum for every lamp and record the peaks
        found in it (stored per element in self._peaks).'''

        # one extracted spectrum per lamp, taken from that lamp's master image
        self.aperture.arcs = {}
        for lamp in self.elements:
            self.aperture.arcs[lamp] = self.aperture.extract(
                                n=lamp,
                                image=self.aperture.images[lamp],
                                arc=True)

        # locate the peaks in each of those spectra
        self._peaks = {}
        for lamp in self.elements:

            # use the raw counts from the narrowest extraction width
            narrowest = np.min(self.aperture.trace.extractionwidths)
            flux = self.aperture.arcs[lamp][narrowest]['raw_counts']

            # identify my peaks
            found = zachopy.oned.peaks(
                                self.aperture.waxis,
                                flux,
                                plot=False,
                                xsmooth=30,
                                threshold=100,
                                edgebuffer=10,
                                widthguess=1,
                                maskwidth=3,
                                returnfiltered=True)
            xPeak, yPeak, xfiltered, yfiltered = found
            self.aperture.arcs[lamp]['filtered'] = yfiltered

            # drop peaks closer than `pad` to either end of the axis
            pad = 25
            lowest = np.min(self.aperture.waxis) + pad
            highest = np.max(self.aperture.waxis) - pad
            inside = (xPeak < highest)*(xPeak > lowest)
            xPeak, yPeak = xPeak[inside], yPeak[inside]

            # store the survivors, flagged as neither hand-picked nor outlier
            self._peaks[lamp] = [
                {
                'w':position,
                'intensity':height,
                'handpicked':False,
                'outlier':False
                }
                for position, height in zip(xPeak, yPeak)]

    def findKnownWavelengths(self):
        # create a temporary calibration to match reference wavelengths to reference pixels (so we can extrapolate to additional wavelengths not recorded in the dispersion solution file)




        self.knownwavelengths = {}

        # treat the arc lamps separately
        for count, element in enumerate(self.elements):
            # pull out the wavelengths from the complete file
            self.knownwavelengths[element] = []
            for i in range(len(self.waveids)):
                if element in self.waveids['name'][i]:
                    wave = self.waveids['wavelength'][i]
                    pixel = self.waveids['pixel'][i]
                    known = {
                            'element':element,
                            'wavelength':wave,
                            'pixelguess':pixel
                            }
                    self.knownwavelengths[element].append(known)

    @property
    def matchesfilename(self):
        return self.wavelengthprefix + 'wavelengthmatches.npy'

    def saveMatches(self):
        '''Store the matches and the known wavelengths together in the
        matches file.'''
        self.speak('saving wavelength dictionaries to {}'.format(self.matchesfilename))
        payload = (self.matches, self.knownwavelengths)
        np.save(self.matchesfilename, payload)

    def loadMatches(self):
        '''Read the matches and the known wavelengths back from the
        matches file.'''
        stored = np.load(self.matchesfilename)
        self.matches, self.knownwavelengths = stored
        self.speak('loaded wavelength matches from {0}'.format(self.matchesfilename))

    def guessMatches(self):
        self.matches = []

        # do identification with one arc at a time
        for element in self.elements:
            # pull out my peaks and theirs
            myPeaks = np.array([p['w'] for p in self.peaks[element]])
            theirPeaksOnMyPixels = np.array([known['pixelguess']
                                            for known
                                            in self.knownwavelengths[element]])

            thiselement = []
            # loop over my peaks
            for i in range(len(myPeaks)):

                # find my closest peak to theirs
                distance = myPeaks[i] - theirPeaksOnMyPixels
                closest = np.sort(np.nonzero(np.abs(distance) == np.min(np.abs(distance)))[0])[0]

                if distance[closest] < self.matchdistance:
                    # call this a match (no matter how far)
                    match = {
                            'mine':self.peaks[element][i],
                            'theirs':self.knownwavelengths[element][closest],
                            'distance':distance[closest]
                            }
                    thiselement.append(match)

            self.matches.extend(thiselement)
            '''for theirs in theirPeaksOnMyPixels:


                relevant = [m for m in thiselement if
                            m['theirs']['pixelguess'] == theirs]

                if len(relevant) > 0:
                    print "{0} of my peaks match to their {1}".format(len(relevant), theirs)

                    distances = np.abs([m['theirs']['pixelguess'] - m['mine']['w'] for m in relevant])


                    best = distances.argmin()

                    print "the closests one is {0}".format(relevant[best]['mine']['w'])
                    self.matches.append(relevant[best])
            '''



            # add this to the list
            #self.matches.extend(thiselement)

    def create(self, remake=False):
        '''Populate the wavelength calibration for this aperture.

        Repeats fit -> flag outliers -> refit -> plot until
        self.notconverged is cleared (quit() does that).

        remake: not used in this method's body.'''

        self.speak("populating wavelength calibration")



        # loop through
        self.notconverged = True
        while(self.notconverged):

            # set an initial guess matching wavelengths to my pixles
            #self.guessMatches()

            # do a fit
            # (on the first pass after loading a stored calibration the
            #  loaded polynomial is kept instead of fitting a new one)
            if self.justloaded:
                self.justloaded = False
            else:
                # do an initial fit
                self.pixelstowavelengths = Legendre.fit(
                                                    x=self.pixel,
                                                    y=self.wavelength,
                                                    deg=self.polynomialdegree,
                                                    w=self.weights
                                                    )
            # identify outliers
            # (threshold: 4 x 1.48 x MAD of the good residuals, at least 1.0)
            limit = 1.48*zachopy.oned.mad(self.residuals[self.good])*4
            limit = np.maximum(limit, 1.0)
            # outliers get reset each time (so don't lose at edges)
            outlier = np.abs(self.residuals) > limit
            # keep track of which are outlier
            for i, m in enumerate(self.matches):
                self.matches[i]['mine']['outlier'] = outlier[i]
            # NOTE(review): {1} is filled with i, the last loop index (not a
            # count of outliers), and i is undefined when there are no matches
            self.speak('points beyond {0} are outliers ({1})'.format(limit, i))

            # refit; self.weights now reflects the updated outlier flags
            self.pixelstowavelengths = Legendre.fit(
                                                    x=self.pixel,
                                                    y=self.wavelength,
                                                    deg=self.polynomialdegree,
                                                    w=self.weights)

            # show the fit and handle one interactive key press
            self.plotWavelengthFit()
            #self.updatew2p()

    @property
    def weights(self):
        return self.good+self.handpicked*10



    @property
    def residuals(self):
        return  self.wavelength - self.pixelstowavelengths(self.pixel)

    @property
    def good(self):
        isntoutlier = np.array([m['mine']['outlier'] == False for m in self.matches])
        ishandpicked = np.array([m['mine']['handpicked'] for m in self.matches])
        return (isntoutlier + ishandpicked) > 0

    @property
    def pixel(self):
        return np.array([m['mine']['w'] for m in self.matches])

    @property
    def pixelelement(self):
        return np.array([m['theirs']['element'] for m in self.matches])

    @property
    def intensity(self):
        return np.array([m['mine']['intensity'] for m in self.matches])

    @property
    def wavelength(self):
        return np.array([m['theirs']['wavelength'] for m in self.matches])

    @property
    def emissioncolor(self):
        '''Plot color of each match, looked up from its reference element.'''
        shades = [colors[match['theirs']['element']] for match in self.matches]
        return np.array(shades)

    def plotWavelengthFit(self, interactive=True):
        '''Draw the wavelength-calibration figure and, if requested, wait
        for one key press and run the action bound to it.

        interactive: if False, return right after drawing.

        The figure has four stacked panels: a text header, the filtered
        arc spectra with the matched peaks, the fitted pixel-to-wavelength
        curve, and the fit residuals.'''
        # plot to make sure the wavelength calibration makes sense

        # NOTE(review): three values are passed but the format string has
        # only two fields, so the window is never shown
        self.speak('{}, {}'.format(self.pixelstowavelengths, self.pixelstowavelengths.domain, self.pixelstowavelengths.window))
        self.figcal = plt.figure('wavelength calibration',
                                        figsize=(15,6), dpi=72)
        self.interactivewave = zachopy.iplot.iplot(4,1,
                height_ratios=[0.1, 0.4, 0.2, .2], hspace=0.1,
                bottom=0.15)

        self.ax_header = self.interactivewave.subplot(0)

        # do the lamp spectra overlap?
        self.ax_walign = self.interactivewave.subplot(1)

        # what is the actual wavelength calibration
        self.ax_wcal = self.interactivewave.subplot(2, sharex=self.ax_walign)

        # what are the residuals from the fit
        self.ax_wres = self.interactivewave.subplot(3, sharex=self.ax_walign)

        # only the bottom (residuals) panel keeps its x tick labels
        for ax in [self.ax_header, self.ax_walign, self.ax_wcal]:
            plt.setp(ax.get_xticklabels(), visible=False)

        self.ax_header.set_title("Wavelength Calib. for Aperture"
            " (%0.1f,%0.1f)" % (self.aperture.x, self.aperture.y))

        # print information about the wavelength calibration
        self.ax_header.set_xlim(0,1)
        self.ax_header.set_ylim(0,1)
        plt.setp(self.ax_header.get_yticklabels(), visible=False)
        self.ax_header.patch.set_visible(False)
        text = 'Wavelength-to-pixel guesses are {0}\n'.format(self.whichwaveid)
        text += 'Hand-picked matches are from {0}.\n'.format(self.matchesfilename.split('/')[-1])
        text += 'Pixel-to-wavelength calibration is '
        if self.justloaded:
            text += 'from ' + self.calibrationfilename.split('/')[-1]
        else:
            text += '[new!]'
        self.ax_header.text(0.025, 0.5, text,
                                    va='center', ha='left', fontsize=10)
        # legend-like list of the element names, each in its own color
        for i, e in enumerate(self.elements):
            self.ax_header.text(0.98,
                                1.0-(i+1.0)/(len(self.elements)+1),
                                e,
                                ha='right', va='center',
                                color=colors[e],
                                fontsize=6)

        # plot the backwards calibration
        for e in self.elements:
            ok = np.nonzero([e in self.waveids['name'][i] for i in range(len(self.waveids))])[0]

            xvals = self.waveids['pixel'][ok]
            yvals = self.waveids['wavelength'][ok]

            self.ax_header.scatter(xvals, yvals,
                            marker='o', color=colors[e], alpha=0.3)
            # NOTE(review): xfine is only used by the commented-out plot below
            xfine = np.linspace(min(self.waveids['wavelength']), max(self.waveids['wavelength']))
            #self.ax_header.plot(self.waveids, xfine,
            #                    alpha=0.3)

        # plot the overlap of the lamp spectra
        scatterkw = dict(   marker='o', linewidth=0,
                            alpha=0.5, color=self.emissioncolor)
        for element in self.elements:

            self.ax_walign.plot(self.aperture.waxis,
                                self.aperture.arcs[element]['filtered'],
                                color=colors[element], alpha=0.5)
            # (all matched peaks are re-drawn on every pass of this loop)
            self.ax_walign.scatter(self.pixel, self.intensity, **scatterkw)

            self.ax_walign.set_yscale('log')
            self.ax_walign.set_ylim(1, None)


        # plot tick marks for the known wavelengths
        for e in self.elements:
            for tw in self.knownwavelengths[e]:
                pix = tw['pixelguess']
                self.ax_walign.axvline(pix,
                                        ymin=0.9,
                                        color=colors[e],
                                        alpha=0.5)


        # plot the calibration
        x = np.linspace(*zachopy.oned.minmax(self.aperture.waxis),num=200)
        self.ax_wcal.plot(x, self.pixelstowavelengths(x), alpha=0.5, color='black')
        self.ax_wcal.set_ylabel('Wavelength (angstroms)')

        # good matches as circles ...
        self.ax_wres.set_ylabel('Residuals')
        scatterkw['color'] = self.emissioncolor[self.good]
        self.ax_wcal.scatter(   self.pixel[self.good],
                                self.wavelength[self.good],
                                **scatterkw)
        self.ax_wres.scatter(   self.pixel[self.good],
                                self.residuals[self.good],
                                **scatterkw)
        # ... and the remaining ones as crosses
        scatterkw['marker'] = 'x'
        bad = self.good == False
        scatterkw['color'] = self.emissioncolor[bad]
        self.ax_wcal.scatter(self.pixel[bad], self.wavelength[bad], **scatterkw)
        self.ax_wres.scatter(self.pixel[bad], self.residuals[bad], **scatterkw)
        self.ax_wres.set_xlabel('Pixel # (by python rules)')

        self.ax_wres.set_xlim(*zachopy.oned.minmax(self.aperture.waxis))

        # mark the hand-picked matches with a black '+' in every panel
        self.ax_walign.scatter(   self.pixel[self.handpicked],
                                self.intensity[self.handpicked],
                                marker='+', color='black')

        self.ax_wcal.scatter(   self.pixel[self.handpicked],
                                self.wavelength[self.handpicked],
                                marker='+', color='black')

        self.ax_wres.scatter(   self.pixel[self.handpicked],
                                self.residuals[self.handpicked],
                                marker='+', color='black')

        self.ax_wres.axhline(0, linestyle='--', color='gray', zorder=-100)


        # one-line summary of the fit quality, written into the residual panel
        rms = np.std(self.residuals[self.good])
        performance = 'using a {}-degree {}'.format(self.polynomialdegree,
                            self.pixelstowavelengths.__class__.__name__)
        performance += ' the calibration has an RMS of {0:.2f}A'.format(rms)
        performance += ' with {0:.0f} good points'.format(np.sum(self.good))
        performance += ' from {:.0f} to {:.0f}'.format(*zachopy.oned.minmax(self.pixel[self.good]))
        self.ax_wres.text(0.98, 0.05, performance,
                            fontsize=10,
                            ha='right', va='bottom',
                            transform=self.ax_wres.transAxes)

        # symmetric y range: 10 x MAD of the good residuals, at least +/-1
        self.ax_wres.set_ylim(*np.array([-1,1])*np.maximum(zachopy.oned.mad(self.residuals[self.good])*10, 1))
        plt.draw()
        self.speak('check out the wavelength calibration')

        if interactive == False:
            return

        # key bindings: each entry has a description, the bound method and
        # whether the key must be pressed inside an axes
        options = {}
        options['q'] = dict(description='[q]uit without writing',
                            function=self.quit,
                            requiresposition=False)

        options['w'] = dict(description='[w]rite the new calibration',
                            function=self.saveandquit,
                            requiresposition=False)

        options['h'] = dict(description='match up a [h]elium line',
                            function=self.see,
                            requiresposition=True)
        options['n'] = dict(description='match up a [n]eon line',
                            function=self.see,
                            requiresposition=True)
        options['a'] = dict(description='match up a [a]rgon line',
                            function=self.see,
                            requiresposition=True)

        options['d'] = dict(description='change the polynomial [d]egree',
                            function=self.changedegree,
                            requiresposition=True)

        options['z'] = dict(description='[z]ap a previously matched line',
                            function=self.zap,
                            requiresposition=True)

        options['u'] = dict(description='[u]ndo all changes made this session',
                                    function=self.undo,
                                    requiresposition=False)

        options['r'] = dict(description='[r]estart from all defaults',
                                    function=self.restart,
                                    requiresposition=False)

        options['!'] = dict(description='raise an error[!]',
                                    function=self.freakout,
                                    requiresposition=False)

        self.speak('your options include:')
        for v in options.values():
            self.speak('   ' + v['description'])
        pressed = self.interactivewave.getKeyboard()

        try:
            # figure out which option we're on
            thing = options[pressed.key.lower()]
            # check that it's a valid position, if need be
            if thing['requiresposition']:
                assert(pressed.inaxes is not None)
            # execute the function associated with this option
            thing['function'](pressed)
        except KeyError:
            # NOTE(review): a KeyError raised inside the bound function is
            # reported here as well, as if the key were unknown
            self.speak("nothing yet defined for [{}]".format(pressed.key))
            return
        except AssertionError:
            self.speak("that didn't seem to be at a valid position!")
            return

    def freakout(self, *args):
        '''raise a RuntimeError on purpose (the "!" option); any arguments are ignored'''
        raise RuntimeError('Breaking here, for debugginng')

    def undo(self, *args):
        '''[u]ndo option: call self.populate() with no arguments; *args is ignored'''
        self.populate()

    def restart(self, *args):
        '''[r]estart option: call self.populate(restart=True); *args is ignored'''
        self.populate(restart=True)

    def changedegree(self, pressed):
        '''change the degree of the polynomial

        Asks for one more key press and stores it, converted to an integer,
        in self.polynomialdegree. If that key is not an integer (a letter,
        a modifier name such as "shift", or no key at all), the current
        degree is left untouched and a message is spoken, instead of letting
        int() raise out of the interactive loop.

        `pressed` is the key event that selected this option; it is unused.
        '''
        self.speak('please enter a new polynomial degree (1-9)')
        new = self.interactivewave.getKeyboard()

        try:
            degree = int(new.key)
        except (TypeError, ValueError):
            # TypeError: key is None; ValueError: key is not a digit
            self.speak('"{0}" is not a valid polynomial degree; '
                       'keeping the current one'.format(new.key))
            return

        self.polynomialdegree = degree

    def quit(self, *args):
        '''quit without saving

        Announces that we are finished and sets self.notconverged to False;
        any arguments are ignored.
        '''
        self.speak('finished!')
        # presumably the flag the interactive loop tests to keep going -- TODO confirm
        self.notconverged = False

    def saveandquit(self, *args):
        '''save and quit

        Calls self.save() and then self.quit(); any arguments are ignored.
        '''
        self.speak('saving the wavelength calibration')
        self.save()
        self.quit()



    @property
    def handpicked(self):
        '''array holding the 'handpicked' flag of every entry in self.matches'''
        flags = []
        for entry in self.matches:
            flags.append(entry['mine']['handpicked'])
        return np.array(flags)

    def selectpeak(self, pressed):
        '''return the index of the detected peak closest to a key press

        The pressed key picks an element through the module-level
        `shortcuts` table. Only peaks whose entry in self.pixelelement
        equals that element are considered, and among those the one whose
        self.pixel value lies nearest to pressed.xdata is chosen.

        Returns the index into self.pixel, or None when there is no peak of
        the requested element (the caller, `see`, already checks for None).
        '''
        # lower-case the key, as selectwavelength and the option dispatcher
        # do, so an upper-case press selects the same element
        element = shortcuts[pressed.key.lower()]
        self.speak('matching for the closest {0} line'.format(element))

        valid = np.where([e == element for e in self.pixelelement])[0]
        if len(valid) == 0:
            # np.argmin raises on an empty array; report it instead
            self.speak('no {0} lines are available to match'.format(element))
            return None

        closest = valid[np.argmin(np.abs(self.pixel[valid] - pressed.xdata))]
        return closest


    def selectwavelength(self, pressed):
        '''return the known wavelength entry closest to a key press

        The pressed key (case-insensitive) picks an element through the
        module-level `shortcuts` table; the entry of
        self.knownwavelengths[element] whose 'pixelguess' is nearest to
        pressed.xdata is returned.
        '''
        element = shortcuts[pressed.key.lower()]
        candidates = self.knownwavelengths[element]
        # build an array explicitly: subtracting from a plain list only
        # works when xdata happens to be a NumPy scalar
        pixguess = np.array([w['pixelguess'] for w in candidates])
        closest = ((pixguess - pressed.xdata)**2).argmin()
        return candidates[closest]

    def checkvalid(self, pressed):
        '''return True if the key press carries an x coordinate, else complain and return False'''
        if pressed.xdata is not None:
            return True

        complaint = '''you typed "{0}", but the window didn't return a coordinate, please try again!'''
        self.speak(complaint.format(pressed.key))
        return False

    def zap(self, pressed):
        '''clear the 'handpicked' flag of the match nearest to the key press'''
        if not self.checkvalid(pressed):
            return

        # squared distance from every entry of self.pixel to the press position
        offsets = self.pixel - pressed.xdata
        nearest = (offsets**2).argmin()
        self.matches[nearest]['mine']['handpicked'] = False


    def see(self, pressed):
        '''match a detected peak to a known He/Ne/Ar wavelength by hand

        Two key presses are involved:
          1. `pressed` (already received) selects the nearest detected peak
             of the element given by its key, via self.selectpeak.
          2. a second key press, requested here, selects the nearest known
             wavelength of the element given by *its* key, via
             self.selectwavelength.

        On success the chosen entry of self.matches is flagged as
        handpicked, its 'theirs' pixel guess is set to the peak position and
        its 'mine' wavelength is set to the chosen known wavelength.
        Returns None in every case; an invalid first or second press simply
        aborts without changing anything in self.matches.
        '''
        if self.checkvalid(pressed) == False:
            return

        closest = self.selectpeak(pressed)
        if closest is None:
            return

        matched = self.matches[closest]

        pixel = matched['mine']['w']
        # '.1f' prints one decimal place; a bare '.1' precision means one
        # significant digit and would show e.g. 1234.5 as "1e+03"
        self.speak('you are e[X]cited about an indentifiable emission line at {0:.1f}'.format(pixel))
        self.ax_walign.axvline(pixel, alpha=0.5, color=colors[matched['theirs']['element']])

        self.speak('  now, please click a [H]e, [N]e, or [A]r line')

        # get a keyboard input at mouse position
        secondpressed = self.interactivewave.getKeyboard()
        key = secondpressed.key
        if key is None or key.lower() not in ['h', 'n', 'a']:
            self.speak("hmmm...ignoring")
            return
        # the second press needs a coordinate too, otherwise the distance
        # computation in selectwavelength fails on xdata=None
        if self.checkvalid(secondpressed) == False:
            return

        # pull out the closest wavelength match
        wavematch = self.selectwavelength(secondpressed)

        self.matches[closest]['mine']['handpicked'] = True
        self.speak('updating the guess for {0}A from {1} to {2}'.format(
                        wavematch['wavelength'],
                        wavematch['pixelguess'],
                        pixel
        ))
        self.matches[closest]['theirs']['pixelguess'] = pixel
        self.matches[closest]['mine']['wavelength'] = wavematch['wavelength']


    def exclude(self, pressed):
        '''clear the 'handpicked' flag of the match selected in the self.ax_wres axes

        Presses made in any other axes only produce the spoken message.
        '''
        self.speak('excluding a point')
        if pressed.inaxes != self.ax_wres:
            return

        # NOTE(review): self.select is not defined in the visible part of
        # this class -- confirm it exists
        index = self.select(pressed)
        self.matches[index]['mine']['handpicked'] = False
コード例 #39
0
def fit_legendres_images(images,
                         centers,
                         lg_inds,
                         rad_inds,
                         maxPixel,
                         rotate=0,
                         image_stds=None,
                         image_counts=None,
                         image_nanMaps=None,
                         image_weights=None,
                         chiSq_fit=False,
                         rad_range=None):
    """
    Fits legendre polynomials to an array of single images (3d) or a list/array of 
    an array of scan images, possible dimensionality:
        1) [NtimeSteps, image_rows, image_cols]
        2) [NtimeSteps (list), Nscans, image_rows, image_cols]

    For every radius index the pixels on that ring are binned by polar
    angle, neighbouring angle bins are averaged to midpoints, and the
    Legendre polynomials of orders lg_inds, evaluated at cos(angle), are
    fitted to those midpoint values with `normal_eqn_vects`.

    Parameters
    ----------
    images : sequence of arrays, one entry per time step. Whether the
        entries carry a leading scan axis is decided from images[0] alone.
    centers : per-image centre, read as centers[im, 0] / centers[im, 1]
        (row / column) without scans and centers[im][isc, 0] /
        centers[im][isc, 1] with scans. The values are added to the pixel
        offsets in rad_inds and used as array indices.
    lg_inds : sequence of Legendre orders to fit.
    rad_inds : rad_inds[rad] is a pair (row offsets, column offsets) of the
        pixels belonging to radius index rad, relative to the centre.
    maxPixel : radius indices 0 .. maxPixel-1 are processed.
    rotate : angle in radians added to every pixel's polar angle.
    image_stds : per-pixel standard deviations, same layout as images.
        Defaults to 1 everywhere (0 at NaN pixels). Required if chiSq_fit.
    image_counts : per-pixel counts used as averaging weights, same layout
        as images. Defaults to 1 everywhere (0 at NaN pixels).
    image_nanMaps : optional masks, same layout as images; pixels where the
        mask is truthy are ignored.
    image_weights : not supported yet, see Raises.
    chiSq_fit : if True each angle bin is weighted by 1/variance.
    rad_range : optional (low, high); only radius indices with
        low <= rad < high are processed.

    Returns
    -------
    (img_fits, img_covs), or None when chiSq_fit is set without image_stds.
    Per image the fits are stacked as (scan, radius, legendre order); the
    covariances come from `normal_eqn_vects` with the radius axis inserted
    at position 1 (the NaN placeholder used for unfittable scans has shape
    (scan, legendre order, legendre order)). If every image has the same
    number of scans the per-image results are combined into one array,
    otherwise they stay lists. Without scans the scan axis is dropped.
    Scans that cannot be fitted at a radius (fewer than two usable angle
    bins, or a singular overlap matrix) get NaN there.

    Raises
    ------
    NotImplementedError
        If image_weights is given. This path used to call sys.exit(0),
        ending the whole process with a success status.
    """

    if image_counts is None:
        # every finite pixel counts once, NaN pixels not at all
        image_counts = []
        for im in range(len(images)):
            image_counts.append(np.ones_like(images[im]))
            image_counts[im][np.isnan(images[im])] = 0

    if chiSq_fit and (image_stds is None):
        print("If using the chiSq fit you must supply image_stds")
        return None

    if image_stds is None:
        image_stds = []
        for im in range(len(images)):
            image_stds.append(np.ones_like(images[im]))
            image_stds[im][np.isnan(images[im])] = 0

    # True when each images[im] has at least 3 axes, i.e. a leading scan axis
    with_scans = len(images[0].shape) + 1 >= 4

    img_fits = [[] for x in range(len(images))]
    img_covs = [[] for x in range(len(images))]
    for rad in range(maxPixel):
        if rad_range is not None:
            if rad < rad_range[0] or rad >= rad_range[1]:
                continue
        if rad % 25 == 0:
            print("Fitting radius {}".format(rad))

        # polar angle of every pixel on this ring, rotated and wrapped
        # into (-pi, pi]
        all_angles = np.arctan2(rad_inds[rad][1].astype(float),
                                rad_inds[rad][0].astype(float))
        all_angles[all_angles < 0] += 2 * np.pi
        all_angles = np.mod(all_angles + rotate, 2 * np.pi)
        all_angles[all_angles > np.pi] -= 2 * np.pi
        if np.sum(np.mod(lg_inds, 2)) == 0:
            # only even orders requested: fold |angle| > pi/2 back into
            # [-pi/2, pi/2]
            all_angles[np.abs(all_angles) > np.pi / 2.] -= np.pi * np.sign(
                all_angles[np.abs(all_angles) > np.pi / 2.])

        # np.unique returns the distinct |angle| values already sorted;
        # return_index gives, per distinct value, the position of its first
        # pixel. When every pixel has its own angle this is exactly the
        # permutation that puts the pixel columns in sorted-angle order.
        # (An argsort of the sorted unique values is just the identity and
        # would leave the columns in rad_inds order.)
        abs_angles = np.abs(all_angles)
        angles, ang_sort_inds = np.unique(abs_angles, return_index=True)
        Nangles = angles.shape[0]

        if len(angles) == len(all_angles):
            do_merge = False
        else:
            # several pixels share an angle: build a sparse
            # (angle bin x pixel) matrix of ones that sums them per bin
            do_merge = True
            mi_rows, mi_cols = [], []
            for ia, ang in enumerate(angles):
                inds = np.where(abs_angles == ang)[0]
                mi_rows.append(np.ones_like(inds) * ia)
                mi_cols.append(inds)
            mi_rows, mi_cols = np.concatenate(mi_rows), np.concatenate(mi_cols)

            merge_indices = csr_matrix(
                (np.ones_like(mi_rows), (mi_rows, mi_cols)),
                shape=(len(angles), len(all_angles)))

        for im in range(len(images)):

            # Gather the ring's pixels as 2-d arrays (scan, pixel on ring);
            # without scans there is a single row.
            if with_scans:
                angs_tile = np.tile(angles, (images[im].shape[0], 1))
                scn_inds, row_inds, col_inds = [], [], []
                for isc in range(images[im].shape[0]):
                    scn_inds.append(
                        np.ones(rad_inds[rad][0].shape[0], dtype=int) * isc)
                    row_inds.append(rad_inds[rad][0] + centers[im][isc, 0])
                    col_inds.append(rad_inds[rad][1] + centers[im][isc, 1])

                scn_inds = np.concatenate(scn_inds)
                row_inds = np.concatenate(row_inds)
                col_inds = np.concatenate(col_inds)
                img_pixels = np.reshape(
                    copy(images[im][scn_inds, row_inds, col_inds]),
                    (images[im].shape[0], -1))
                img_counts = np.reshape(
                    copy(image_counts[im][scn_inds, row_inds, col_inds]),
                    (images[im].shape[0], -1))
                img_stds = np.reshape(
                    copy(image_stds[im][scn_inds, row_inds, col_inds]),
                    (images[im].shape[0], -1))
                if image_nanMaps is not None:
                    img_pixels[np.reshape(
                        image_nanMaps[im][scn_inds, row_inds, col_inds],
                        (images[im].shape[0], -1)).astype(bool)] = np.nan
                    img_counts[np.reshape(
                        image_nanMaps[im][scn_inds, row_inds, col_inds],
                        (images[im].shape[0], -1)).astype(bool)] = 0
                if image_weights is not None:
                    img_weights = np.reshape(
                        copy(image_weights[im][scn_inds, row_inds, col_inds]),
                        (images[im].shape[0], -1))
            else:
                angs_tile = np.expand_dims(angles, 0)
                row_inds = rad_inds[rad][0] + centers[im, 0]
                col_inds = rad_inds[rad][1] + centers[im, 1]
                img_pixels = np.reshape(copy(images[im][row_inds, col_inds]),
                                        (1, -1))
                img_counts = np.reshape(
                    copy(image_counts[im][row_inds, col_inds]), (1, -1))
                img_stds = np.reshape(copy(image_stds[im][row_inds, col_inds]),
                                      (1, -1))
                if image_nanMaps is not None:
                    img_pixels[np.reshape(
                        image_nanMaps[im][row_inds, col_inds],
                        (1, -1)).astype(bool)] = np.nan
                    img_counts[np.reshape(
                        image_nanMaps[im][row_inds, col_inds],
                        (1, -1)).astype(bool)] = 0
                if image_weights is not None:
                    img_weights = np.reshape(
                        copy(image_weights[im][row_inds, col_inds]), (1, -1))

            # count-weighted sums; NaN pixels contribute nothing
            img_pix = img_pixels * img_counts
            img_var = img_counts * (img_stds**2)
            img_pix[np.isnan(img_pixels)] = 0
            img_var[np.isnan(img_pixels)] = 0

            if do_merge:
                # sum all pixels that fall into the same angle bin
                img_pix = np.transpose(merge_indices.dot(
                    np.transpose(img_pix)))
                img_var = np.transpose(merge_indices.dot(
                    np.transpose(img_var)))

                img_counts = np.transpose(
                    merge_indices.dot(np.transpose(img_counts)))

                if image_weights is not None:
                    raise NotImplementedError(
                        "image_weights is not supported yet: merging "
                        "weighted pixels (and their stds) still has to be "
                        "filled in")
            else:
                # one pixel per angle: just put the columns in angle order
                img_pix = img_pix[:, ang_sort_inds]
                img_var = img_var[:, ang_sort_inds]
                img_counts = img_counts[:, ang_sort_inds]
            # count-weighted mean per angle bin; bins with zero counts
            # become NaN (0/0)
            img_pix /= img_counts
            img_var /= img_counts

            # Compact the usable bins of each scan to the left and replace
            # neighbouring bins by their midpoint.
            Nnans = np.sum(np.isnan(img_pix), axis=-1)
            ang_inds = np.where(img_counts > 0)
            arr_inds = np.concatenate(
                [np.arange(Nangles - Nn) for Nn in Nnans])

            img_pixels = np.zeros_like(img_pix)
            img_vars = np.zeros_like(img_var)
            img_angs = np.zeros_like(img_pix)
            img_dang = np.zeros_like(img_pix)

            # value / variance / angle at the midpoint of consecutive usable
            # bins, and the angular width between them
            img_pixels[ang_inds[0][:-1], arr_inds[:-1]] =\
                    (img_pix[ang_inds[0][:-1], ang_inds[1][:-1]] + img_pix[ang_inds[0][1:], ang_inds[1][1:]])/2.
            img_vars[ang_inds[0][:-1], arr_inds[:-1]] =\
                    (img_var[ang_inds[0][:-1], ang_inds[1][:-1]] + img_var[ang_inds[0][1:], ang_inds[1][1:]])/2.
            img_angs[ang_inds[0][:-1], arr_inds[:-1]] =\
                    (angs_tile[ang_inds[0][:-1], ang_inds[1][:-1]] + angs_tile[ang_inds[0][1:], ang_inds[1][1:]])/2.
            img_dang[ang_inds[0][:-1], arr_inds[:-1]] =\
                    (angs_tile[ang_inds[0][1:], ang_inds[1][1:]] - angs_tile[ang_inds[0][:-1], ang_inds[1][:-1]])

            # The pairing above runs over the flattened list of usable bins,
            # so the last bin of one scan gets paired with the first bin of
            # the next; zeroing the tail of every row removes those entries.
            for isc in range(Nnans.shape[0]):
                # Using angle midpoint => one less angle => Nnans[isc]+1
                img_pixels[isc, -1 * (Nnans[isc] + 1):] = 0
                img_vars[isc, -1 * (Nnans[isc] + 1):] = 0
                img_angs[isc, -1 * (Nnans[isc] + 1):] = 0
                img_dang[isc, -1 * (Nnans[isc] + 1):] = 0

            if image_weights is not None:
                raise NotImplementedError(
                    "image_weights is not supported yet: the weighted fit "
                    "still has to be filled in and checked")
            elif chiSq_fit:
                img_weights = 1. / img_vars
                img_weights[img_vars == 0] = 0
            else:
                img_weights = np.ones_like(img_pixels)
            # sin(angle) * d(angle) factor for each midpoint; the zeroed
            # tail entries get weight 0 because their width is 0
            img_weights *= np.sin(img_angs) * img_dang
            lgndrs = []
            for lg in lg_inds:
                lgndrs.append(Legendre.basis(lg)(np.cos(img_angs)))
            # (legendre order, scan, angle) -> (scan, legendre order, angle)
            lgndrs = np.transpose(np.array(lgndrs), (1, 0, 2))

            # A scan cannot be fitted with fewer than two weighted points or
            # when its weighted overlap matrix of the basis is singular.
            empty_scan = np.sum(img_weights.astype(bool), -1) < 2
            overlap = np.einsum('bai,bi,bci->bac',
                                lgndrs[np.invert(empty_scan)],
                                img_weights[np.invert(empty_scan)],
                                lgndrs[np.invert(empty_scan)],
                                optimize='greedy')
            empty_scan[np.invert(empty_scan)] = (np.linalg.det(overlap) == 0.0)

            if np.any(empty_scan):
                # NaN placeholders for the unfittable scans, fit the rest
                fit = np.ones((img_pixels.shape[0], len(lg_inds))) * np.nan
                cov = np.ones(
                    (img_pixels.shape[0], len(lg_inds), len(lg_inds))) * np.nan

                if np.any(np.invert(empty_scan)):
                    img_pixels = img_pixels[np.invert(empty_scan)]
                    img_weights = img_weights[np.invert(empty_scan)]
                    img_vars = img_vars[np.invert(empty_scan)]
                    lgndrs = lgndrs[np.invert(empty_scan)]

                    fit[np.invert(empty_scan)], cov[np.invert(empty_scan)] =\
                        normal_eqn_vects(lgndrs, img_pixels, img_weights, img_vars)
            else:
                fit, cov = normal_eqn_vects(lgndrs, img_pixels, img_weights,
                                            img_vars)
            # insert the radius axis at position 1 so that radii can be
            # concatenated below
            img_fits[im].append(np.expand_dims(fit, 1))
            img_covs[im].append(np.expand_dims(cov, 1))

    # Stack the radii per image; Nscans ends up as the common number of
    # scans, or -1 if the images disagree.
    Nscans = None
    for im in range(len(img_fits)):
        img_fits[im] = np.concatenate(img_fits[im], 1)
        img_covs[im] = np.concatenate(img_covs[im], 1)
        if Nscans is None:
            Nscans = img_fits[im].shape[0]
        elif Nscans != img_fits[im].shape[0]:
            Nscans = -1
    if Nscans > 0:
        img_fits = np.array(img_fits)
        img_covs = np.array(img_covs)

    if with_scans:
        return img_fits, img_covs
    else:
        # single images: drop the length-1 scan axis
        return img_fits[:, 0, :, :], img_covs[:, 0, :, :]