Example #1
def test_flatimages():
    tmp = np.ones((1000, 100)) * 10.
    x = np.random.rand(500)
    # Create bspline
    spat_bspline1 = bspline.bspline(x, bkspace=0.01 * (np.max(x) - np.min(x)))
    spat_bspline2 = bspline.bspline(x, bkspace=0.01 * (np.max(x) - np.min(x)))
    instant_dict = dict(procflat=tmp,
                        pixelflat=np.ones_like(tmp),
                        flat_model=None,
                        spat_bsplines=np.asarray(
                            [spat_bspline1, spat_bspline2]),
                        spat_id=np.asarray([100, 200]))

    flatImages = flatfield.FlatImages(**instant_dict)
    assert flatImages.flat_model is None

    # I/O
    outfile = data_path('tst_flatimages.fits')
    flatImages.to_file(outfile, overwrite=True)
    _flatImages = flatfield.FlatImages.from_file(outfile)
    # Test
    for key in instant_dict.keys():
        if key == 'spat_bsplines':
            assert np.array_equal(flatImages[key][0].breakpoints,
                                  _flatImages[key][0].breakpoints)
            continue
        if isinstance(instant_dict[key], np.ndarray):
            assert np.array_equal(flatImages[key], _flatImages[key])
        else:
            assert flatImages[key] == _flatImages[key]

    os.remove(outfile)
Example #2
def test_bsplinetodict():
    """
    Test for writing a bspline onto a dict (and also reading it out).
    """
    x = np.random.rand(500)

    # Create bspline
    init_bspline = bspline.bspline(x, bkspace=0.01*(np.max(x)-np.min(x)))
    # Write bspline to bspline_dict
    bspline_dict = init_bspline.to_dict()
    # Create bspline from bspline_dict
    bspline_fromdict = bspline.bspline(None, from_dict=bspline_dict)

    assert np.max(np.array(bspline_dict['breakpoints'])-bspline_fromdict.breakpoints) == 0.
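
A note on the ``bkspace`` argument used above: it sets the breakpoint spacing in the same units as ``x``, so ``bkspace=0.01*(np.max(x)-np.min(x))`` divides the data range into roughly 100 intervals (plus the padding the class adds for the spline order). A minimal sketch of that relation, assuming the same pypeit import used in the examples here:

import numpy as np
from pypeit import bspline  # assumed import path, as in the examples above

x = np.random.rand(500)
bkspace = 0.01 * (np.max(x) - np.min(x))
sset = bspline.bspline(x, bkspace=bkspace)

# The breakpoints should span the data range at (roughly) the requested spacing
print(sset.breakpoints.size)
print(np.median(np.diff(sset.breakpoints)), bkspace)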
Example #3
    def spatial_fit(self, norm_spec, spat_coo, median_slit_width, spat_gpm, gpm, debug=False):
        """
        Perform the spatial fit

        Args:
            norm_spec (`numpy.ndarray`_):
                Flat-field image normalized by its spectral response.
            spat_coo (`numpy.ndarray`_):
                Spatial coordinate array
            median_slit_width (:obj:`float`):
                Median slit width in pixels.
            spat_gpm (`numpy.ndarray`_):
                Good-pixel mask selecting the pixels to include in the
                spatial fit.
            gpm (`numpy.ndarray`_):
                Global good-pixel mask; updated in place if
                ``rej_sticky`` is set.
            debug (bool, optional):
                Show debugging plots.

        Returns:
            tuple: 7 objects
                 - exit_status (int): Success/failure flag returned by the bspline fit
                 - spat_coo_data: Spatial coordinates of the data included in the fit
                 - spat_flat_data: Empirical illumination profile data that was fit
                 - spat_bspl (:class:`pypeit.bspline.bspline.bspline`): Bspline model of the spatial fit.  Used for illumflat
                 - spat_gpm_fit: Good-pixel mask returned by the bspline fit
                 - spat_flat_fit: Bspline model evaluated at the fitted coordinates
                 - spat_flat_data_raw: Raw (unfiltered) illumination profile data
        """

        # Construct the empirical illumination profile
        _spat_gpm, spat_srt, spat_coo_data, spat_flat_data_raw, spat_flat_data \
            = flat.construct_illum_profile(norm_spec, spat_coo, median_slit_width,
                                           spat_gpm=spat_gpm,
                                           spat_samp=self.flatpar['spat_samp'],
                                           illum_iter=self.flatpar['illum_iter'],
                                           illum_rej=self.flatpar['illum_rej'],
                                           debug=debug)

        if self.flatpar['rej_sticky']:
            # Add rejected pixels to gpm
            gpm[spat_gpm] &= (spat_gpm & _spat_gpm)[spat_gpm]

        # Make sure that the normalized and filtered flat is finite!
        if np.any(np.invert(np.isfinite(spat_flat_data))):
            msgs.error('Infinities in slit illumination function computation!')

        # Determine the breakpoint spacing from the sampling of the
        # spatial coordinates. Use breakpoints at a spacing of 1/10th
        # of a pixel, but do not allow a breakpoint spacing smaller
        # than the typical sampling. Use the bspline class to
        # determine the breakpoints:
        spat_bspl = bspline.bspline(spat_coo_data, nord=4,
                                    bkspace=np.fmax(1.0 / median_slit_width / 10.0,
                                                    1.2 * np.median(np.diff(spat_coo_data))))
        # TODO: Can we add defaults to bspline_profile so that we
        # don't have to instantiate invvar and profile_basis
        spat_bspl, spat_gpm_fit, spat_flat_fit, _, exit_status \
            = utils.bspline_profile(spat_coo_data, spat_flat_data,
                                    np.ones_like(spat_flat_data),
                                    np.ones_like(spat_flat_data), nord=4, upper=5.0,
                                    lower=5.0, fullbkpt=spat_bspl.breakpoints)
        # Return
        return exit_status, spat_coo_data, spat_flat_data, spat_bspl, spat_gpm_fit, \
               spat_flat_fit, spat_flat_data_raw
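
For reference, the breakpoint-spacing rule used in spatial_fit can be evaluated in isolation. A minimal sketch with placeholder inputs (the median_slit_width and spat_coo_data values below are made up, not pipeline data):

import numpy as np

# Placeholder inputs: a ~120-pixel-wide slit sampled by 500 normalized spatial coordinates
median_slit_width = 120.0
spat_coo_data = np.sort(np.random.rand(500))

# Breakpoints at 1/10th of a pixel, but no finer than ~1.2x the typical sampling
bkspace = np.fmax(1.0 / median_slit_width / 10.0,
                  1.2 * np.median(np.diff(spat_coo_data)))
print(bkspace)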
Example #4
def test_flatimages():
    tmp = np.ones((1000, 100)) * 10.
    x = np.random.rand(500)
    # Create bspline
    spat_bspline1 = bspline.bspline(x, bkspace=0.01 * (np.max(x) - np.min(x)))
    spat_bspline2 = bspline.bspline(x, bkspace=0.01 * (np.max(x) - np.min(x)))
    instant_dict = dict(
        pixelflat_raw=tmp,
        pixelflat_norm=np.ones_like(tmp),
        pixelflat_model=None,
        pixelflat_spat_bsplines=np.asarray([spat_bspline1, spat_bspline2]),
        pixelflat_spec_illum=None,
        illumflat_raw=tmp,
        illumflat_spat_bsplines=np.asarray([spat_bspline1, spat_bspline2]),
        spat_id=np.asarray([100, 200]),
        PYP_SPEC="specname")

    flatImages = flatfield.FlatImages(**instant_dict)
    assert flatImages.pixelflat_model is None
    assert flatImages.pixelflat_spec_illum is None
    assert flatImages.pixelflat_spat_bsplines is not None

    # I/O
    outfile = data_path('tst_flatimages.fits')
    flatImages.to_master_file(outfile)
    _flatImages = flatfield.FlatImages.from_file(outfile)

    # Test
    for key in instant_dict.keys():
        if key == 'pixelflat_spat_bsplines':
            assert np.array_equal(flatImages[key][0].breakpoints,
                                  _flatImages[key][0].breakpoints)
            continue
        if key == 'illumflat_spat_bsplines':
            assert np.array_equal(flatImages[key][0].breakpoints,
                                  _flatImages[key][0].breakpoints)
            continue
        if isinstance(instant_dict[key], np.ndarray):
            assert np.array_equal(flatImages[key], _flatImages[key])
        else:
            assert flatImages[key] == _flatImages[key]

    os.remove(outfile)
Example #5
def test_profile_spat():
    """
    Test that bspline_profile (1) is successful and (2) produces the
    same result for a set of data fit spatially.
    """
    # Files created using `rmtdict` branch (30 Jan 2020)
    files = [data_path('gemini_gnirs_32_{0}_spat_fit.npz'.format(slit)) for slit in [0,1]]
    for f in files:
        d = np.load(f)
        spat_bspl = bspline.bspline(d['spat_coo_data'], nord=4,
                                    bkspace=np.fmax(1.0/d['median_slit_width']/10.0,
                                                    1.2*np.median(np.diff(d['spat_coo_data']))))
        spat_bspl, spat_gpm_fit, spat_flat_fit, _, exit_status \
                = fitting.bspline_profile(d['spat_coo_data'], d['spat_flat_data'],
                                  np.ones_like(d['spat_flat_data']),
                                  np.ones_like(d['spat_flat_data']), nord=4, upper=5.0, lower=5.0,
                                  fullbkpt=spat_bspl.breakpoints, quiet=True)
        assert np.allclose(d['spat_flat_fit'], spat_flat_fit), 'Bad spatial bspline result'
Example #6
def test_io():
    """
    Test that bspline_profile (1) is successful and (2) produces the
    same result for a set of data fit spectrally.
    """
    # Files created using `rmtdict` branch (30 Jan 2020)
    files = [
        data_path('gemini_gnirs_32_{0}_spec_fit.npz'.format(slit))
        for slit in [0, 1]
    ]
    logrej = 0.5
    spec_samp_fine = 1.2
    for f in files:
        d = np.load(f)
        spec_bspl, spec_gpm_fit, spec_flat_fit, _, exit_status \
            = fitting.bspline_profile(d['spec_coo_data'], d['spec_flat_data'], d['spec_ivar_data'],
                              np.ones_like(d['spec_coo_data']), ingpm=d['spec_gpm_data'],
                              nord=4, upper=logrej, lower=logrej,
                              kwargs_bspline={'bkspace': spec_samp_fine},
                              kwargs_reject={'groupbadpix': True, 'maxrej': 5}, quiet=True)
    # Write
    ofile = data_path('tst_bspline.fits')
    if os.path.isfile(ofile):
        os.remove(ofile)
    spec_bspl.to_file(ofile)
    # Read
    _spec_bspl = bspline.bspline.from_file(ofile)
    # Check that the data read in is the same
    assert np.array_equal(_spec_bspl.breakpoints,
                          spec_bspl.breakpoints), 'Bad read'
    # Evaluate the bsplines and check that the written and read in data
    # provide the same result
    tmp = spec_bspl.value(np.linspace(0., 1., 100))[0]
    _tmp = _spec_bspl.value(np.linspace(0., 1., 100))[0]
    assert np.array_equal(tmp, _tmp), 'Bad bspline evaluate'

    # Test overwrite
    _spec_bspl.to_file(ofile, overwrite=True)

    # None
    bspl = bspline.bspline(None)
    bspl.to_file(ofile, overwrite=True)

    os.remove(ofile)
Example #7
def standard_sensfunc(wave, flux, ivar, mask_bad, flux_std, mask_balm=None, mask_tell=None,
                 maxiter=35, upper=2.0, lower=2.0, polyorder=5, balm_mask_wid=50., nresln=20., telluric=True,
                 resolution=2700., polycorrect=True, debug=False, polyfunc=False, show_QA=False):
    """
    Generate a sensitivity function based on observed flux and standard spectrum.

    Parameters
    ----------
    wave : ndarray
      wavelength as observed
    flux : ndarray
      counts/s as observed
    ivar : ndarray
      inverse variance
    flux_std : Quantity array
      standard star true flux (erg/s/cm^2/A)
    mask_bad : ndarray
      mask for bad pixels. True is good.
    mask_balm : ndarray, optional
      mask for hydrogen recombination lines. True is good.
    mask_tell : ndarray, optional
      mask for telluric regions. True is good.
    maxiter : integer
      maximum number of iterations for polynomial fit
    upper : integer
      number of sigma for rejection in polynomial
    lower : integer
      number of sigma for rejection in polynomial
    polyorder : integer
      order of polynomial fit
    balm_mask_wid : float
      width, in Angstroms, of the region masked around each hydrogen
      recombination (Balmer) line
    nresln : float
      number of resolution elements used to set the bspline breakpoint
      spacing
    resolution : integer or float
      spectral resolution. This parameter should be removed in the
      future; the resolution should be estimated from the spectra
      directly.
    debug : bool
      if True, show some debugging plots

    Returns
    -------
    sensfunc : ndarray
      sensitivity function
    masktot : ndarray
      mask of the pixels used to compute the sensitivity function.
      True is good.
    """
    # Create copy of the arrays to avoid modification
    wave_obs = wave.copy()
    flux_obs = flux.copy()
    ivar_obs = ivar.copy()
    # preparing arrays
    if np.any(np.invert(np.isfinite(ivar_obs))):
        msgs.warn("NaN are present in the inverse variance")

    # check masks
    if mask_tell is None:
        mask_tell = np.ones_like(wave_obs,dtype=bool)
    if mask_balm is None:
        mask_balm = np.ones_like(wave_obs, dtype=bool)

    # Removing outliers
    # Calculate log of flux_obs setting a floor at TINY
    logflux_obs = 2.5 * np.log10(np.maximum(flux_obs, TINY))
    # Set a fixed value for the variance of logflux
    logivar_obs = np.ones_like(logflux_obs) * (10.0 ** 2)
    # Calculate log of flux_std model setting a floor at TINY
    logflux_std = 2.5 * np.log10(np.maximum(flux_std, TINY))
    # Calculate ratio setting a floor at MAGFUNC_MIN and a ceiling at
    # MAGFUNC_MAX
    magfunc = logflux_std - logflux_obs
    magfunc = np.maximum(np.minimum(magfunc, MAGFUNC_MAX), MAGFUNC_MIN)
    msk_magfunc = (magfunc < 0.99 * MAGFUNC_MAX) & (magfunc > 0.99 * MAGFUNC_MIN)

    # Define two new masks, True is good and False is masked pixel
    # mask for all bad pixels on sensfunc
    masktot = mask_bad & msk_magfunc & np.isfinite(ivar_obs) & np.isfinite(logflux_obs) & np.isfinite(logflux_std)
    logivar_obs[np.invert(masktot)] = 0.0
    # mask used for polynomial fit
    msk_fit_sens = masktot & mask_tell & mask_balm

    # Polynomial fitting to derive a smooth sensfunc (i.e. without telluric)
    pypeitFit = fitting.robust_fit(wave_obs[msk_fit_sens], magfunc[msk_fit_sens], polyorder,
                                             function='polynomial', maxiter=maxiter,
                                             lower=lower, upper=upper,
                                             groupbadpix=False,
                                             grow=0, sticky=True, use_mad=True)
    magfunc_poly = pypeitFit.eval(wave_obs)

    # Polynomial corrections on Hydrogen Recombination lines
    if ((sum(msk_fit_sens) > 0.5 * len(msk_fit_sens)) & polycorrect):
        ## Only correct Hydrogen Recombination lines with polyfit in the telluric free region
        balmer_clean = np.zeros_like(wave_obs, dtype=bool)
        # Commented out the bluest recombination lines since they are weak for spectroscopic standard stars.
        #836.4, 3969.6, 3890.1, 4102.8, 4102.8, 4341.6, 4862.7,   \
        lines_hydrogen = np.array([5407.0, 6564.6, 8224.8, 8239.2, 8203.6, 8440.3, 8469.6, 8504.8, 8547.7, 8600.8, \
                                   8667.4, 8752.9, 8865.2, 9017.4, 9229.0, 10049.4, 10938.1, 12818.1, 21655.0])
        for line_hydrogen in lines_hydrogen:
            ihydrogen = np.abs(wave_obs - line_hydrogen) <= balm_mask_wid
            balmer_clean[ihydrogen] = True
        msk_clean = ((balmer_clean) | (magfunc == MAGFUNC_MAX) | (magfunc == MAGFUNC_MIN)) & \
                    (magfunc_poly > MAGFUNC_MIN) & (magfunc_poly < MAGFUNC_MAX)
        magfunc[msk_clean] = magfunc_poly[msk_clean]
        msk_badpix = np.isfinite(ivar_obs) & (ivar_obs > 0)
        magfunc[np.invert(msk_badpix)] = magfunc_poly[np.invert(msk_badpix)]
    else:
        ## If more than half of the spectrum is masked (or polycorrect=False), do not correct it with the polynomial fit
        msgs.warn('No polynomial corrections performed on Hydrogen Recombination line regions')

    if not telluric:
        # Apply mask to ivar
        #logivar_obs[~msk_fit_sens] = 0.

        # ToDo
        # Compute an effective resolution for the standard. This could be improved
        # to setup an array of breakpoints based on the resolution. At the
        # moment we are using only one number
        msgs.work("Should pull resolution from arc line analysis")
        msgs.work("At the moment the resolution is taken as the PixelScale")
        msgs.work("This needs to be changed!")
        std_pix = np.median(np.abs(wave_obs - np.roll(wave_obs, 1)))
        std_res = np.median(wave_obs/resolution) # median resolution in units of Angstrom.
        #std_res = std_pix
        #resln = std_res
        if (nresln * std_res) < std_pix:
            msgs.warn("Bspline breakpoints spacing shoud be larger than 1pixel")
            msgs.warn("Changing input nresln to fix this")
            nresln = std_res / std_pix

        # Fit magfunc with bspline
        kwargs_bspline = {'bkspace': std_res * nresln}
        kwargs_reject = {'maxrej': 5}
        msgs.info("Initialize bspline for flux calibration")
#        init_bspline = pydl.bspline(wave_obs, bkspace=kwargs_bspline['bkspace'])
        init_bspline = bspline.bspline(wave_obs, bkspace=kwargs_bspline['bkspace'])
        fullbkpt = init_bspline.breakpoints

        # TESTING turning off masking for now
        # remove masked regions from breakpoints
        msk_obs = np.ones_like(wave_obs).astype(bool)
        msk_obs[np.invert(masktot)] = False
        msk_bkpt = interpolate.interp1d(wave_obs, msk_obs, kind='nearest',
                                        fill_value='extrapolate')(fullbkpt)
        init_breakpoints = fullbkpt[msk_bkpt > 0.999]

        # init_breakpoints = fullbkpt
        msgs.info("Bspline fit on magfunc. ")
        bset1, bmask = fitting.iterfit(wave_obs, magfunc, invvar=logivar_obs, inmask=msk_fit_sens, upper=upper, lower=lower,
                                    fullbkpt=init_breakpoints, maxiter=maxiter, kwargs_bspline=kwargs_bspline,
                                    kwargs_reject=kwargs_reject)
        logfit1, _ = bset1.value(wave_obs)
        logfit_bkpt, _ = bset1.value(init_breakpoints)

        if debug:
            # Check for calibration
            plt.figure(1)
            plt.plot(wave_obs, magfunc, drawstyle='steps-mid', color='black', label='magfunc')
            plt.plot(wave_obs, logfit1, color='cornflowerblue', label='logfit1')
            plt.plot(wave_obs[np.invert(msk_fit_sens)], magfunc[np.invert(msk_fit_sens)], '+', color='red', markersize=5.0,
                     label='masked magfunc')
            plt.plot(wave_obs[np.invert(msk_fit_sens)], logfit1[np.invert(msk_fit_sens)], '+', color='red', markersize=5.0,
                     label='masked logfit1')
            plt.plot(init_breakpoints, logfit_bkpt, '.', color='green', markersize=4.0, label='breakpoints')
            plt.plot(init_breakpoints, np.interp(init_breakpoints, wave_obs, magfunc), '.', color='green',
                     markersize=4.0,
                     label='breakpoints')
            plt.plot(wave_obs, 1.0 / np.sqrt(logivar_obs), color='orange', label='sigma')
            plt.legend()
            plt.xlabel('Wavelength [ang]')
            plt.ylim(0.0, 1.2 * MAGFUNC_MAX)
            plt.title('1st Bspline fit')
            plt.show()
        # Create sensitivity function
        magfunc = np.maximum(np.minimum(logfit1, MAGFUNC_MAX), MAGFUNC_MIN)
        if ((sum(msk_fit_sens) > 0.5 * len(msk_fit_sens)) & polycorrect):
            msk_clean = ((magfunc == MAGFUNC_MAX) | (magfunc == MAGFUNC_MIN)) & \
                        (magfunc_poly > MAGFUNC_MIN) & (magfunc_poly<MAGFUNC_MAX)
            magfunc[msk_clean] = magfunc_poly[msk_clean]
            msk_badpix = np.isfinite(ivar_obs) & (ivar_obs>0)
            magfunc[np.invert(msk_badpix)] = magfunc_poly[np.invert(msk_badpix)]
        else:
            ## If more than half of the spectrum is masked (or polycorrect=False), do not correct it with the polynomial fit
            msgs.warn('No polynomial corrections performed on Hydrogen Recombination line regions')

    # Calculate sensfunc
    if polyfunc:
        sensfunc = 10.0 ** (0.4 * magfunc_poly)
        magfunc = magfunc_poly
    else:
        sensfunc = 10.0 ** (0.4 * magfunc)

    if debug:
        plt.figure()
        magfunc_raw = logflux_std - logflux_obs
        plt.plot(wave_obs[masktot],magfunc_raw[masktot] , 'k-',lw=3,label='Raw Magfunc')
        plt.plot(wave_obs[masktot],magfunc_poly[masktot] , 'c-',lw=3,label='Polynomial Fit')
        plt.plot(wave_obs[np.invert(mask_tell)], magfunc_raw[np.invert(mask_tell)], 's',
                 color='0.7',label='Telluric Region')
        plt.plot(wave_obs[np.invert(mask_balm)], magfunc_raw[np.invert(mask_balm)], 'r+',label='Recombination Line region')
        plt.plot(wave_obs[masktot], magfunc[masktot],'b-',label='Final Magfunc')
        plt.legend(fancybox=True, shadow=True)
        plt.xlim([0.995*np.min(wave_obs[masktot]),1.005*np.max(wave_obs[masktot])])
        plt.ylim([0.,1.2*np.max(magfunc[masktot])])
        plt.show()
        plt.close()

    return sensfunc, masktot
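
The conversion at the end of standard_sensfunc is just the magnitude-space version of a flux ratio: magfunc = 2.5*log10(flux_std) - 2.5*log10(flux_obs) and sensfunc = 10**(0.4*magfunc), so away from the MAGFUNC_MIN/MAGFUNC_MAX clipping the sensitivity function reduces to flux_std/flux_obs. A minimal numerical sketch of that identity (placeholder values, not real spectra):

import numpy as np

flux_obs = np.array([100., 200., 400.])    # observed counts/s (placeholder)
flux_std = np.array([1e-16, 3e-16, 2e-16])  # true standard-star flux (placeholder)

magfunc = 2.5 * np.log10(flux_std) - 2.5 * np.log10(flux_obs)
sensfunc = 10.0 ** (0.4 * magfunc)

# Equivalent to the direct ratio when no clipping or smoothing is applied
assert np.allclose(sensfunc, flux_std / flux_obs)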
Example #8
def bspline_profile(xdata,
                    ydata,
                    invvar,
                    profile_basis,
                    ingpm=None,
                    upper=5,
                    lower=5,
                    maxiter=25,
                    nord=4,
                    bkpt=None,
                    fullbkpt=None,
                    relative=None,
                    kwargs_bspline={},
                    kwargs_reject={},
                    quiet=False):
    """
    Fit a B-spline in the least squares sense with rejection to the
    provided data and model profiles.

    .. todo::
        Fully describe procedure.

    Parameters
    ----------
    xdata : `numpy.ndarray`_
        Independent variable.
    ydata : `numpy.ndarray`_
        Dependent variable.
    invvar : `numpy.ndarray`_
        Inverse variance of `ydata`.
    profile_basis : `numpy.ndarray`_
        Model profiles.
    ingpm : `numpy.ndarray`_, optional
        Input good-pixel mask. Values to fit in ``ydata`` should be
        True.
    upper : :obj:`int`, :obj:`float`, optional
        Upper rejection threshold in units of sigma, defaults to 5
        sigma.
    lower : :obj:`int`, :obj:`float`, optional
        Lower rejection threshold in units of sigma, defaults to 5
        sigma.
    maxiter : :obj:`int`, optional
        Maximum number of rejection iterations, default 25. Set this
        to zero to disable rejection.
    nord : :obj:`int`, optional
        Order of B-spline fit
    bkpt : `numpy.ndarray`_, optional
        Array of breakpoints to be used for the b-spline
    fullbkpt : `numpy.ndarray`_, optional
        Full array of breakpoints to be used for the b-spline,
        without letting the b-spline class append on any extra bkpts
    relative : `numpy.ndarray`_, optional
        Array of integer indices to be used for computing the reduced
        chi^2 of the fits, which then is used as a scale factor for
        the upper,lower rejection thresholds
    kwargs_bspline : :obj:`dict`, optional
        Keyword arguments used to instantiate
        :class:`pypeit.bspline.bspline`
    kwargs_reject : :obj:`dict`, optional
        Keyword arguments passed to :func:`pypeit.core.pydl.djs_reject`
    quiet : :obj:`bool`, optional
        Suppress output to the screen

    Returns
    -------
    sset : :class:`pypeit.bspline.bspline`
        Result of the fit.
    gpm : `numpy.ndarray`_
        Output good-pixel mask with the same size as ``xdata``. Values
        are False for data not used in the fit, either because the
        input data were masked or because they were rejected during
        the fit. Data rejected during the fit (if rejection is
        performed) are::

            rejected = ingpm & np.logical_not(gpm)

    yfit : `numpy.ndarray`_
        The best-fitting model; shape is the same as ``xdata``.
    reduced_chi : :obj:`float`
        Reduced chi-square of the best-fitting model.
    exit_status : :obj:`int`
        Indication of the success/failure of the fit.  Values are:

            - 0 = fit exited cleanly
            - 1 = maximum iterations were reached
            - 2 = all points were masked
            - 3 = all break points were dropped
            - 4 = Number of good data points fewer than nord

    """
    # Checks
    nx = xdata.size
    if ydata.size != nx:
        msgs.error('Dimensions of xdata and ydata do not agree.')

    # TODO: invvar and profile_basis should be optional

    # ToDO at the moment invvar is a required variable input
    #    if invvar is not None:
    #        if invvar.size != nx:
    #            raise ValueError('Dimensions of xdata and invvar do not agree.')
    #        else:
    #            #
    #            # This correction to the variance makes it the same
    #            # as IDL's variance()
    #            #
    #            var = ydata.var()*(float(nx)/float(nx-1))
    #            if var == 0:
    #                var = 1.0
    #            invvar = np.ones(ydata.shape, dtype=ydata.dtype)/var

    npoly = int(profile_basis.size / nx)
    if profile_basis.size != nx * npoly:
        msgs.error(
            'Profile basis is not a multiple of the number of data points.')

    # Init
    yfit = np.zeros(ydata.shape)
    reduced_chi = 0.

    # TODO: Instantiating these place-holder arrays can be expensive.  Can we avoid doing this?
    outmask = True if invvar.size == 1 else np.ones(invvar.shape, dtype=bool)

    if ingpm is None:
        ingpm = invvar > 0

    if not quiet:
        termwidth = 80 - 13
        msgs.info('B-spline fit:')
        msgs.info('    npoly = {0} profile basis functions'.format(npoly))
        msgs.info('    ngood = {0}/{1} measurements'.format(
            np.sum(ingpm), ingpm.size))
        msgs.info(' {0:>4}  {1:>8}  {2:>7}  {3:>6} '.format(
            'Iter', 'Chi^2', 'N Rej', 'R. Fac').center(termwidth, '*'))
        hlinestr = ' {0}  {1}  {2}  {3} '.format('-' * 4, '-' * 8, '-' * 7,
                                                 '-' * 6)
        nullval = '  {0:>8}  {1:>7}  {2:>6} '.format('-' * 2, '-' * 2, '-' * 2)
        msgs.info(hlinestr.center(termwidth))

    maskwork = outmask & ingpm & (invvar > 0)
    if not maskwork.any():
        msgs.error('No valid data points in bspline_profile!')

    # Init bspline class
    sset = bspline.bspline(xdata[maskwork],
                           nord=nord,
                           npoly=npoly,
                           bkpt=bkpt,
                           fullbkpt=fullbkpt,
                           funcname='Bspline longslit special',
                           **kwargs_bspline)
    if maskwork.sum() < sset.nord:
        if not quiet:
            msgs.warn('Number of good data points fewer than nord.')
        # TODO: Why isn't maskwork returned?
        return sset, outmask, yfit, reduced_chi, 4

    # This was checked in detail against IDL for identical inputs
    # KBW: Tried a few things and this was about as fast as you can get.
    outer = np.outer(np.ones(nord, dtype=float), profile_basis.flatten('F')).T
    action_multiple = outer.reshape((nx, npoly * nord), order='F')
    # --------------------
    # Iterate spline fit
    iiter = 0
    error = -1  # Indicates that the fit should be done
    qdone = False  # True if rejection iterations are done
    exit_status = 0
    relative_factor = 1.0
    nrel = 0 if relative is None else len(relative)
    # TODO: Why do we need both maskwork and tempin?
    tempin = np.copy(ingpm)
    while (error != 0
           or qdone is False) and iiter <= maxiter and exit_status == 0:
        ngood = maskwork.sum()
        goodbk = sset.mask.nonzero()[0]
        if ngood <= 1 or not sset.mask.any():
            sset.coeff[:] = 0.
            exit_status = 2  # This will end iterations
        else:
            # Do the fit. Return values from workit for error are as follows:
            #    0 if fit is good
            #   -1 if some breakpoints are masked, so try the fit again
            #   -2 if everything is screwed

            # we'll do the fit right here..............
            if error != 0:
                bf1, laction, uaction = sset.action(xdata)
                if np.any(bf1 == -2) or bf1.size != nx * nord:
                    msgs.error("BSPLINE_ACTION failed!")
                action = np.copy(action_multiple)
                for ipoly in range(npoly):
                    action[:, np.arange(nord) * npoly + ipoly] *= bf1
                del bf1  # Clear the memory

            if np.any(np.logical_not(np.isfinite(action))):
                msgs.error(
                    'Infinities in action matrix.  B-spline fit faults.')

            error, yfit = sset.workit(xdata, ydata, invvar * maskwork, action,
                                      laction, uaction)

        iiter += 1

        if error == -2:
            if not quiet:
                msgs.warn('All break points lost!!  Bspline fit failed.')
            exit_status = 3
            return sset, np.zeros(xdata.shape, dtype=bool), np.zeros(xdata.shape), reduced_chi, \
                   exit_status

        if error != 0:
            if not quiet:
                msgs.info(
                    (' {0:4d}'.format(iiter) + nullval).center(termwidth))
            continue

        # Iterate the fit -- next rejection iteration
        chi_array = (ydata - yfit) * np.sqrt(invvar * maskwork)
        reduced_chi = np.sum(np.square(chi_array)) / (ngood - npoly *
                                                      (len(goodbk) + nord) - 1)

        relative_factor = 1.0
        if relative is not None:
            this_chi2 = reduced_chi if nrel == 1 \
                else np.sum(np.square(chi_array[relative])) \
                     / (nrel - (len(goodbk) + nord) - 1)
            relative_factor = max(np.sqrt(this_chi2), 1.0)

        # Rejection

        # TODO: JFH by setting ingpm to be tempin which is maskwork, we
        #  are basically implicitly enforcing sticky rejection here. See
        #  djs_reject.py. I'm leaving this as is for consistency with
        #  the IDL version, but this may require further consideration.
        #  I think requiring sticky to be set is the more transparent
        #  behavior.
        maskwork, qdone = pydl.djs_reject(ydata,
                                          yfit,
                                          invvar=invvar,
                                          inmask=tempin,
                                          outmask=maskwork,
                                          upper=upper * relative_factor,
                                          lower=lower * relative_factor,
                                          **kwargs_reject)
        tempin = np.copy(maskwork)
        if not quiet:
            msgs.info(' {0:4d}  {1:8.3f}  {2:7d}  {3:6.2f} '.format(
                iiter, reduced_chi, np.sum(maskwork == 0),
                relative_factor).center(termwidth))

    if iiter == (maxiter + 1):
        exit_status = 1

    # Exit status:
    #    0 = fit exited cleanly
    #    1 = maximum iterations were reached
    #    2 = all points were masked
    #    3 = all break points were dropped
    #    4 = Number of good data points fewer than nord

    if not quiet:
        msgs.info(' {0:>4}  {1:8.3f}  {2:7d}  {3:6.2f} '.format(
            'DONE', reduced_chi, np.sum(maskwork == 0),
            relative_factor).center(termwidth))
        msgs.info('*' * termwidth)

    # Finish
    # TODO: Why not return maskwork directly
    outmask = np.copy(maskwork)
    # Return
    return sset, outmask, yfit, reduced_chi, exit_status
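
A minimal usage sketch of bspline_profile on synthetic data, assuming the function is importable from pypeit.core.fitting as in the test examples above (the data and the bkspace value are made up):

import numpy as np
from pypeit.core import fitting  # assumed import path, as in the test examples above

# Synthetic data: a smooth curve plus noise, fit with a single (unit) profile basis
x = np.linspace(0., 1., 500)
y = np.sin(2. * np.pi * x) + 0.05 * np.random.randn(x.size)
ivar = np.full_like(y, 1. / 0.05**2)

sset, gpm, yfit, reduced_chi, exit_status \
        = fitting.bspline_profile(x, y, ivar, np.ones_like(y), nord=4,
                                  upper=3., lower=3.,
                                  kwargs_bspline={'bkspace': 0.05}, quiet=True)
print(exit_status, reduced_chi)  # exit_status == 0 for a clean fit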
Example #9
def iterfit(xdata,
            ydata,
            invvar=None,
            inmask=None,
            upper=5,
            lower=5,
            x2=None,
            maxiter=10,
            nord=4,
            bkpt=None,
            fullbkpt=None,
            kwargs_bspline={},
            kwargs_reject={}):
    """Iteratively fit a b-spline set to data, with rejection. This is a utility function that allows
    the bspline to use via a direct function call.

    Parameters
    ----------
    xdata : :class:`numpy.ndarray`
        Independent variable.
    ydata : :class:`numpy.ndarray`
        Dependent variable.
    invvar : :class:`numpy.ndarray`
        Inverse variance of `ydata`.  If not set, it will be calculated based
        on the standard deviation.
    upper : :class:`int` or :class:`float`
        Upper rejection threshold in units of sigma, defaults to 5 sigma.
    lower : :class:`int` or :class:`float`
        Lower rejection threshold in units of sigma, defaults to 5 sigma.
    x2 : :class:`numpy.ndarray`, optional
        Orthogonal dependent variable for 2d fits.
    maxiter : :class:`int`, optional
        Maximum number of rejection iterations, default 10.  Set this to
        zero to disable rejection.

    Returns
    -------
    :func:`tuple`
        A tuple containing the fitted bspline object and an output mask.
    """
    # from .math import djs_reject
    nx = xdata.size
    if ydata.size != nx:
        raise ValueError('Dimensions of xdata and ydata do not agree.')
    if invvar is not None:
        if invvar.size != nx:
            raise ValueError('Dimensions of xdata and invvar do not agree.')
    else:
        #
        # This correction to the variance makes it the same
        # as IDL's variance()
        #
        var = ydata.var() * (float(nx) / float(nx - 1))
        if var == 0:
            var = 1.0
        invvar = np.ones(ydata.shape, dtype=ydata.dtype) / var

    if inmask is None:
        inmask = invvar > 0.0

    if x2 is not None:
        if x2.size != nx:
            raise ValueError('Dimensions of xdata and x2 do not agree.')
    yfit = np.zeros(ydata.shape)
    if invvar.size == 1:
        outmask = True
    else:
        outmask = np.ones(invvar.shape, dtype='bool')
    xsort = xdata.argsort()
    maskwork = (outmask & inmask & (invvar > 0.0))[xsort]
    if 'oldset' in kwargs_bspline:
        sset = kwargs_bspline['oldset']
        sset.mask[:] = True
        sset.coeff[:] = 0.
    else:
        if not maskwork.any():
            raise ValueError('No valid data points.')
            # return (None,None)
        # JFH comment this out for now
        #        if 'fullbkpt' in kwargs:
        #            fullbkpt = kwargs['fullbkpt']
        else:
            sset = bspline.bspline(xdata[xsort[maskwork]],
                                   nord=nord,
                                   bkpt=bkpt,
                                   fullbkpt=fullbkpt,
                                   **kwargs_bspline)
            if maskwork.sum() < sset.nord:
                print('Number of good data points fewer than nord.')
                return (sset, outmask)
            if x2 is not None:
                if 'xmin' in kwargs_bspline:
                    xmin = kwargs_bspline['xmin']
                else:
                    xmin = x2.min()
                if 'xmax' in kwargs_bspline:
                    xmax = kwargs_bspline['xmax']
                else:
                    xmax = x2.max()
                if xmin == xmax:
                    xmax = xmin + 1
                sset.xmin = xmin
                sset.xmax = xmax
                if 'funcname' in kwargs_bspline:
                    sset.funcname = kwargs_bspline['funcname']
    xwork = xdata[xsort]
    ywork = ydata[xsort]
    invwork = invvar[xsort]
    if x2 is not None:
        x2work = x2[xsort]
    else:
        x2work = None
    iiter = 0
    error = -1
    qdone = False
    while (error != 0 or qdone is False) and iiter <= maxiter:
        goodbk = sset.mask.nonzero()[0]
        if maskwork.sum() <= 1 or not sset.mask.any():
            sset.coeff[:] = 0.
            iiter = maxiter + 1  # End iterations
        else:
            if 'requiren' in kwargs_bspline:
                i = 0
                while xwork[i] < sset.breakpoints[goodbk[
                        sset.nord]] and i < nx - 1:
                    i += 1
                ct = 0
                for ileft in range(sset.nord, sset.mask.sum() - sset.nord + 1):
                    while (xwork[i] >= sset.breakpoints[goodbk[ileft]]
                           and xwork[i] < sset.breakpoints[goodbk[ileft + 1]]
                           and i < nx - 1):
                        ct += invwork[i] * maskwork[i] > 0
                        i += 1
                    if ct >= kwargs_bspline['requiren']:
                        ct = 0
                    else:
                        sset.mask[goodbk[ileft]] = False
            error, yfit = sset.fit(xwork, ywork, invwork * maskwork, x2=x2work)
        iiter += 1
        inmask_rej = maskwork
        if error == -2:

            return (sset, outmask)
        elif error == 0:
            # TODO: JFH by setting inmask to be tempin which is maskwork, we are basically implicitly enforcing sticky rejection
            # here. See djs_reject.py. I'm leaving this as is for consistency with the IDL version, but this may require
            # further consideration. I think requiring sticky to be set is the more transparent behavior.
            maskwork, qdone = pydl.djs_reject(ywork,
                                              yfit,
                                              invvar=invwork,
                                              inmask=inmask_rej,
                                              outmask=maskwork,
                                              upper=upper,
                                              lower=lower,
                                              **kwargs_reject)
        else:
            pass
    outmask[xsort] = maskwork
    # Un-sort the fitted values back to the input pixel order
    temp = yfit.copy()
    yfit[xsort] = temp
    return (sset, outmask)
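
A minimal usage sketch of iterfit with a single injected outlier, assuming the same pypeit.core.fitting import as above (values are made up):

import numpy as np
from pypeit.core import fitting  # assumed import path, matching the function above

# Synthetic data with one strong outlier to exercise the rejection
x = np.linspace(0., 1., 300)
y = x**2 + 0.01 * np.random.randn(x.size)
y[150] += 5.0

sset, outmask = fitting.iterfit(x, y, invvar=np.full_like(y, 1.e4),
                                upper=3, lower=3,
                                kwargs_bspline={'bkspace': 0.1})
yfit, _ = sset.value(x)
print(outmask[150])  # the outlier should be rejected (False)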
Example #10
    def fit(self, debug=False):
        """
        Construct a model of the flat-field image.

        For this method to work, :attr:`rawflatimg` must have been
        previously constructed; see :func:`build_pixflat`.

        The method loops through all slits provided by the :attr:`slits`
        object, except those that have been masked (i.e., slits with a
        non-zero ``self.slits.mask`` value are skipped).  For each slit:

            - Collapse the flat-field data spatially using the
              wavelength coordinates provided by the fit to the arc-line
              traces (:class:`pypeit.wavetilts.WaveTilts`), and fit the
              result with a bspline.  This provides the
              spatially-averaged spectral response of the instrument.
              The data used in the fit is trimmed toward the slit
              spatial center via the ``slit_trim`` parameter in
              :attr:`flatpar`.
            - Use the bspline fit to construct and normalize out the
              spectral response.
            - Collapse the normalized flat-field data spatially using a
              coordinate system defined by the left slit edge.  The data
              included in the spatial (illumination) profile calculation
              is expanded beyond the nominal slit edges using the
              ``slit_illum_pad`` parameter in :attr:`flatpar`.  The raw,
              collapsed data is then median filtered (see ``spat_samp``
              in :attr:`flatpar`) and Gaussian filtered; see
              :func:`pypeit.core.flat.illum_filter`.  This creates an
              empirical, highly smoothed representation of the
              illumination profile that is fit with a bspline using
              the :func:`spatial_fit` method.  The
              construction of the empirical illumination profile (i.e.,
              before the bspline fitting) can be done iteratively, where
              each iteration sigma-clips outliers; see the
              ``illum_iter`` and ``illum_rej`` parameters in
              :attr:`flatpar` and
              :func:`pypeit.core.flat.construct_illum_profile`.
            - If requested, the 1D illumination profile is used to
              "tweak" the slit edges by offsetting them to a threshold
              of the illumination peak to either side of the slit center
              (see ``tweak_slits_thresh`` in :attr:`flatpar`), up to a
              maximum allowed shift from the existing slit edge (see
              ``tweak_slits_maxfrac`` in :attr:`flatpar`).  See
              :func:`pypeit.core.flat.tweak_slit_edges`.  If tweaked, the
              :func:`spatial_fit` is repeated to place it on the tweaked
              slits reference frame.
            - Use the bspline fit to construct the 2D illumination image
              (:attr:`msillumflat`) and normalize out the spatial
              response.
            - Fit the residuals of the flat-field data that has been
              independently normalized for its spectral and spatial
              response with a 2D bspline-polynomial fit.  The order of
              the polynomial has been optimized via experimentation; it
              can be changed but you should use extreme caution when
              doing so (see ``twod_fit_npoly``).  The multiplication of
              the 2D spectral response, 2D spatial response, and joint
              2D fit to the high-order residuals define the final flat
              model (:attr:`flat_model`).
            - Finally, the pixel-to-pixel response of the instrument is
              defined as the ratio of the raw flat data to the
              best-fitting flat-field model (:attr:`mspixelflat`)

        This method is the primary method that builds the
        :class:`FlatField` instance, constructing :attr:`mspixelflat`,
        :attr:`msillumflat`, and :attr:`flat_model`.  All of these
        attributes are altered internally.  If the slit edges are to be
        tweaked using the 1D illumination profile (``tweak_slits`` in
        :attr:`flatpar`), the tweaked slit edge arrays in the internal
        :class:`pypeit.edgetrace.SlitTraceSet` object, :attr:`slits`,
        are also altered.

        Used parameters from :attr:`flatpar`
        (:class:`pypeit.par.pypeitpar.FlatFieldPar`) are
        ``spec_samp_fine``, ``spec_samp_coarse``, ``spat_samp``,
        ``tweak_slits``, ``tweak_slits_thresh``,
        ``tweak_slits_maxfrac``, ``rej_sticky``, ``slit_trim``,
        ``slit_illum_pad``, ``illum_iter``, ``illum_rej``,
        ``twod_fit_npoly``, and ``saturated_slits``.

        **Revision History**:

            - 11-Mar-2005  First version written by Scott Burles.
            - 2005-2018    Improved by J. F. Hennawi and J. X. Prochaska
            - 3-Sep-2018 Ported to python by J. F. Hennawi and significantly improved

        Args:
            debug (:obj:`bool`, optional):
                Show plots useful for debugging. This will block
                further execution of the code until the plot windows
                are closed.

        """
        # TODO: break up this function!  Can it be partitioned into a series of "core" methods?
        # TODO: JFH I wrote all this code and will have to maintain it and I don't want to see it broken up.
        # TODO: JXP This definitely needs breaking up..

        # Init
        self.list_of_spat_bsplines = []

        # Set parameters (for convenience)
        spec_samp_fine = self.flatpar['spec_samp_fine']
        spec_samp_coarse = self.flatpar['spec_samp_coarse']
        tweak_slits = self.flatpar['tweak_slits']
        tweak_slits_thresh = self.flatpar['tweak_slits_thresh']
        tweak_slits_maxfrac = self.flatpar['tweak_slits_maxfrac']
        # If sticky, points rejected at each stage (spec, spat, 2d) are
        # propagated to the next stage
        sticky = self.flatpar['rej_sticky']
        trim = self.flatpar['slit_trim']
        pad = self.flatpar['slit_illum_pad']
        # Iteratively construct the illumination profile by rejecting outliers
        npoly = self.flatpar['twod_fit_npoly']
        saturated_slits = self.flatpar['saturated_slits']

        # Setup images
        nspec, nspat = self.rawflatimg.image.shape
        rawflat = self.rawflatimg.image
        # Good pixel mask
        gpm = np.ones_like(rawflat, dtype=bool) if self.rawflatimg.bpm is None else (
                1-self.rawflatimg.bpm).astype(bool)

        # Flat-field modeling is done in the log of the counts
        flat_log = np.log(np.fmax(rawflat, 1.0))
        gpm_log = (rawflat > 1.0) & gpm
        # set errors to just be 0.5 in the log
        ivar_log = gpm_log.astype(float)/0.5**2

        # Other setup
        nonlinear_counts = self.spectrograph.nonlinear_counts(self.rawflatimg.detector)

        # TODO -- JFH -- CONFIRM THIS SHOULD BE ON INIT
        # It does need to be *all* of the slits
        median_slit_widths = np.median(self.slits.right_init - self.slits.left_init, axis=0)

        if tweak_slits:
            # NOTE: This copies the input slit edges to a set that can be tweaked.
            self.slits.init_tweaked()

        # TODO: This needs to include a padding check
        # Construct three versions of the slit ID image, all of unmasked slits!
        #   - an image that uses the padding defined by self.slits
        slitid_img_init = self.slits.slit_img(initial=True)
        #   - an image that uses the extra padding defined by
        #     self.flatpar. This was always 5 pixels in the previous
        #     version.
        padded_slitid_img = self.slits.slit_img(initial=True, pad=pad)
        #   - and an image that trims the width of the slit using the
        #     parameter in self.flatpar. This was always 3 pixels in
        #     the previous version.
        # TODO: Fix this for when trim is a tuple
        trimmed_slitid_img = self.slits.slit_img(pad=-trim, initial=True)

        # Prep for results
        self.mspixelflat = np.ones_like(rawflat)
        self.msillumflat = np.ones_like(rawflat)
        self.flat_model = np.zeros_like(rawflat)

        # Allocate work arrays only once
        spec_model = np.ones_like(rawflat)
        norm_spec = np.ones_like(rawflat)
        norm_spec_spat = np.ones_like(rawflat)
        twod_model = np.ones_like(rawflat)

        # #################################################
        # Model each slit independently
        for slit_idx, slit_spat in enumerate(self.slits.spat_id):
            # Is this a good slit??
            if self.slits.mask[slit_idx] != 0:
                msgs.info('Skipping bad slit: {}'.format(slit_spat))
                self.list_of_spat_bsplines.append(bspline.bspline(None))
                continue

            msgs.info('Modeling the flat-field response for slit spat_id={}: {}/{}'.format(
                        slit_spat, slit_idx+1, self.slits.nslits))

            # Find the pixels on the initial slit
            onslit_init = slitid_img_init == slit_spat

            # Check for saturation of the flat. If there are not enough
            # pixels do not attempt a fit, and continue to the next
            # slit.
            # TODO: set the threshold to a parameter?
            good_frac = np.sum(onslit_init & (rawflat < nonlinear_counts))/np.sum(onslit_init)
            if good_frac < 0.5:
                common_message = 'To change the behavior, use the \'saturated_slits\' parameter ' \
                                 'in the \'flatfield\' parameter group; see here:\n\n' \
                                 'https://pypeit.readthedocs.io/en/latest/pypeit_par.html \n\n' \
                                 'You could also choose to use a different flat-field image ' \
                                 'for this calibration group.'
                if saturated_slits == 'crash':
                    msgs.error('Only {:4.2f}'.format(100*good_frac)
                               + '% of the pixels on slit {0} are not saturated.  '.format(slit_spat)
                               + 'Selected behavior was to crash if this occurred.  '
                               + common_message)
                elif saturated_slits == 'mask':
                    self.slits.mask[slit_idx] = self.slits.bitmask.turn_on(self.slits.mask[slit_idx], 'BADFLATCALIB')
                    msgs.warn('Only {:4.2f}'.format(100*good_frac)
                                                + '% of the pixels on slit {0} are not saturated.  '.format(slit_spat)
                              + 'Selected behavior was to mask this slit and continue with the '
                              + 'remainder of the reduction, meaning no science data will be '
                              + 'extracted from this slit.  ' + common_message)
                elif saturated_slits == 'continue':
                    self.slits.mask[slit_idx] = self.slits.bitmask.turn_on(self.slits.mask[slit_idx], 'SKIPFLATCALIB')
                    msgs.warn('Only {:4.2f}'.format(100*good_frac)
                              + '% of the pixels on slit {0} are not saturated.  '.format(slit_spat)
                              + 'Selected behavior was to simply continue, meaning no '
                              + 'flat-fielding correction will be applied to this slit but '
                              + 'pypeit will attempt to extract any objects found on this slit.  '
                              + common_message)
                else:
                    # Should never get here
                    raise NotImplementedError('Unknown behavior for saturated slits: {0}'.format(
                                              saturated_slits))
                self.list_of_spat_bsplines.append(bspline.bspline(None))
                continue

            # Demand at least 10 pixels per row (on average) per degree
            # of the polynomial.
            # NOTE: This is not used until the 2D fit. Defined here to
            # be close to the definition of ``onslit``.
            if npoly is None:
                # Approximate number of pixels sampling each spatial pixel
                # for this (original) slit.
                npercol = np.fmax(np.floor(np.sum(onslit_init)/nspec),1.0)
                npoly = np.clip(int(np.ceil(npercol/10.)), 1, 7)
            
            # TODO: Always calculate the optimized `npoly` and warn the
            #  user if npoly is provided but higher than the nominal
            #  calculation?

            # Create an image with the spatial coordinates relative to the left edge of this slit
            spat_coo_init = self.slits.spatial_coordinate_image(slitidx=slit_idx, full=True, initial=True)

            # Find pixels on the padded and trimmed slit coordinates
            onslit_padded = padded_slitid_img == slit_spat
            onslit_trimmed = trimmed_slitid_img == slit_spat

            # ----------------------------------------------------------
            # Collapse the slit spatially and fit the spectral function
            # TODO: Put this stuff in a self.spectral_fit method?

            # Create the tilts image for this slit
            # TODO -- JFH Confirm the sign of this shift is correct!
            _flexure = 0. if self.wavetilts.spat_flexure is None else self.wavetilts.spat_flexure
            tilts = tracewave.fit2tilts(rawflat.shape, self.wavetilts['coeffs'][:,:,slit_idx],
                                        self.wavetilts['func2d'], spat_shift=-1*_flexure)
            # Convert the tilt image to an image with the spectral pixel index
            spec_coo = tilts * (nspec-1)

            # Only include the trimmed set of pixels in the flat-field
            # fit along the spectral direction.
            spec_gpm = onslit_trimmed & gpm_log  # & (rawflat < nonlinear_counts)
            spec_nfit = np.sum(spec_gpm)
            spec_ntot = np.sum(onslit_init)
            msgs.info('Spectral fit of flatfield for {0}/{1} '.format(spec_nfit, spec_ntot)
                      + ' pixels in the slit.')
            # Set this to a parameter?
            if spec_nfit/spec_ntot < 0.5:
                # TODO: Shouldn't this raise an exception or continue to the next slit instead?
                msgs.warn('Spectral fit includes only {:.1f}'.format(100*spec_nfit/spec_ntot)
                          + '% of the pixels on this slit.' + msgs.newline()
                          + '          Either the slit has many bad pixels or the number of '
                            'trimmed pixels is too large.')

            # Sort the pixels by their spectral coordinate.
            # TODO: Include ivar and sorted gpm in outputs?
            spec_gpm, spec_srt, spec_coo_data, spec_flat_data \
                    = flat.sorted_flat_data(flat_log, spec_coo, gpm=spec_gpm)
            # NOTE: By default np.argsort sorts the data over the last
            # axis. Just to avoid the possibility (however unlikely) of
            # spec_coo[spec_gpm] returning an array, all the arrays are
            # explicitly flattened.
            spec_ivar_data = ivar_log[spec_gpm].ravel()[spec_srt]
            spec_gpm_data = gpm_log[spec_gpm].ravel()[spec_srt]

            # Rejection threshold for spectral fit in log(image)
            # TODO: Make this a parameter?
            logrej = 0.5

            # Fit the spectral direction of the blaze.
            # TODO: Figure out how to deal with the fits going crazy at
            #  the edges of the chip in spec direction
            # TODO: Can we add defaults to bspline_profile so that we
            #  don't have to instantiate invvar and profile_basis
            spec_bspl, spec_gpm_fit, spec_flat_fit, _, exit_status \
                    = utils.bspline_profile(spec_coo_data, spec_flat_data, spec_ivar_data,
                                            np.ones_like(spec_coo_data), ingpm=spec_gpm_data,
                                            nord=4, upper=logrej, lower=logrej,
                                            kwargs_bspline={'bkspace': spec_samp_fine},
                                            kwargs_reject={'groupbadpix': True, 'maxrej': 5})

            if exit_status > 1:
                # TODO -- MAKE A FUNCTION
                msgs.warn('Flat-field spectral response bspline fit failed!  Not flat-fielding '
                          'slit {0} and continuing!'.format(slit_spat))
                self.slits.mask[slit_idx] = self.slits.bitmask.turn_on(self.slits.mask[slit_idx], 'BADFLATCALIB')
                self.list_of_spat_bsplines.append(bspline.bspline(None))
                continue

            # Debugging/checking spectral fit
            if debug:
                utils.bspline_qa(spec_coo_data, spec_flat_data, spec_bspl, spec_gpm_fit,
                                 spec_flat_fit, xlabel='Spectral Pixel', ylabel='log(flat counts)',
                                 title='Spectral Fit for slit={:d}'.format(slit_spat))

            if sticky:
                # Add rejected pixels to gpm
                gpm[spec_gpm] = (spec_gpm_fit & spec_gpm_data)[np.argsort(spec_srt)]

            # Construct the model of the flat-field spectral shape
            # including padding on either side of the slit.
            spec_model[...] = 1.
            spec_model[onslit_padded] = np.exp(spec_bspl.value(spec_coo[onslit_padded])[0])
            # ----------------------------------------------------------

            # ----------------------------------------------------------
            # To fit the spatial response, first normalize out the
            # spectral response, and then collapse the slit spectrally.

            # Normalize out the spectral shape of the flat
            norm_spec[...] = 1.
            norm_spec[onslit_padded] = rawflat[onslit_padded] \
                                            / np.fmax(spec_model[onslit_padded],1.0)

            # Find pixels for the fit in the spatial direction:
            #   - Fit pixels in the padded slit that haven't been masked
            #     by the BPM
            spat_gpm = onslit_padded & gpm #& (rawflat < nonlinear_counts)
            #   - Fit pixels with non-zero flux and less than 70% above
            #     the average spectral profile.
            spat_gpm &= (norm_spec > 0.0) & (norm_spec < 1.7)
            #   - Determine maximum counts in median filtered flat
            #     spectrum model.
            spec_interp = interpolate.interp1d(spec_coo_data, spec_flat_fit, kind='linear',
                                               assume_sorted=True, bounds_error=False,
                                               fill_value=-np.inf)
            spec_sm = utils.fast_running_median(np.exp(spec_interp(np.arange(nspec))),
                                                np.fmax(np.ceil(0.10*nspec).astype(int),10))
            #   - Only fit pixels with at least values > 10% of this maximum and no less than 1.
            spat_gpm &= (spec_model > 0.1*np.amax(spec_sm)) & (spec_model > 1.0)

            # Report
            spat_nfit = np.sum(spat_gpm)
            spat_ntot = np.sum(onslit_padded)
            msgs.info('Spatial fit of flatfield for {0}/{1} '.format(spat_nfit, spat_ntot)
                      + ' pixels in the slit.')
            if spat_nfit/spat_ntot < 0.5:
                # TODO: Shouldn't this raise an exception or continue to the next slit instead?
                msgs.warn('Spatial fit includes only {:.1f}'.format(100*spat_nfit/spat_ntot)
                          + '% of the pixels on this slit.' + msgs.newline()
                          + '          Either the slit has many bad pixels, the model of the '
                          'spectral shape is poor, or the illumination profile is very irregular.')

            # First fit -- With initial slits
            exit_status, spat_coo_data,  spat_flat_data, spat_bspl, spat_gpm_fit, \
                spat_flat_fit, spat_flat_data_raw \
                        = self.spatial_fit(norm_spec, spat_coo_init, median_slit_widths[slit_idx],
                                           spat_gpm, gpm, debug=debug)

            if tweak_slits:
                # TODO: Should the tweak be based on the bspline fit?
                # TODO: Will this break if
                left_thresh, left_shift, self.slits.left_tweak[:,slit_idx], right_thresh, \
                    right_shift, self.slits.right_tweak[:,slit_idx] \
                        = flat.tweak_slit_edges(self.slits.left_init[:,slit_idx],
                                                self.slits.right_init[:,slit_idx],
                                                spat_coo_data, spat_flat_data,
                                                thresh=tweak_slits_thresh,
                                                maxfrac=tweak_slits_maxfrac, debug=debug)
                # TODO: Because the padding doesn't consider adjacent
                #  slits, calling slit_img for individual slits can be
                #  different from the result when you construct the
                #  image for all slits. Fix this...

                # Update the onslit mask
                _slitid_img = self.slits.slit_img(slitidx=slit_idx, initial=False)
                onslit_tweak = _slitid_img == slit_spat
                spat_coo_tweak = self.slits.spatial_coordinate_image(slitidx=slit_idx,
                                                               slitid_img=_slitid_img)
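                # The spatial coordinates are recomputed here so that they are
                # normalized by the tweaked (rather than the initial) slit
                # edges before the illumination profile is re-fit.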

                # Construct the empirical illumination profile
                # TODO: This is extremely inefficient, because we only need to re-fit the illumflat, but
                #  spatial_fit does both the reconstruction of the illumination function and the bspline fitting.
                #  Only the b-spline fitting needs to be redone with the new tweaked spatial coordinates, so that
                #  would save a lot of runtime. It is not a trivial change because the coords are sorted, etc.
                exit_status, spat_coo_data, spat_flat_data, spat_bspl, spat_gpm_fit, \
                    spat_flat_fit, spat_flat_data_raw = self.spatial_fit(
                    norm_spec, spat_coo_tweak, median_slit_widths[slit_idx], spat_gpm, gpm, debug=False)

                spat_coo_final = spat_coo_tweak
            else:
                _slitid_img = slitid_img_init
                spat_coo_final = spat_coo_init
                onslit_tweak = onslit_init

            # Add an approximate pixel axis at the top
            if debug:
                # TODO: Move this into a qa plot that gets saved
                ax = utils.bspline_qa(spat_coo_data, spat_flat_data, spat_bspl, spat_gpm_fit,
                                      spat_flat_fit, show=False)
                ax.scatter(spat_coo_data, spat_flat_data_raw, marker='.', s=1, zorder=0, color='k',
                           label='raw data')
                # Force the center of the slit to be at the center of the plot for the hline
                ax.set_xlim(-0.1,1.1)
                ax.axvline(0.0, color='lightgreen', linestyle=':', linewidth=2.0,
                           label='original left edge', zorder=8)
                ax.axvline(1.0, color='red', linestyle=':', linewidth=2.0,
                           label='original right edge', zorder=8)
                if tweak_slits and left_shift > 0:
                    label = 'threshold = {:5.2f}'.format(100*tweak_slits_thresh) \
                                + '% of max of left illumprofile'
                    ax.axhline(left_thresh, xmax=0.5, color='lightgreen', linewidth=3.0,
                               label=label, zorder=10)
                    ax.axvline(left_shift, color='lightgreen', linestyle='--', linewidth=3.0,
                               label='tweaked left edge', zorder=11)
                if tweak_slits and right_shift > 0:
                    label = 'threshold = {:5.2f}'.format(100*tweak_slits_thresh) \
                                + '% of max of right illumprofile'
                    ax.axhline(right_thresh, xmin=0.5, color='red', linewidth=3.0, label=label,
                               zorder=10)
                    ax.axvline(1-right_shift, color='red', linestyle='--', linewidth=3.0,
                               label='tweaked right edge', zorder=20)
                ax.legend()
                ax.set_xlabel('Normalized Slit Position')
                ax.set_ylabel('Normflat Spatial Profile')
                ax.set_title('Illumination Function Fit for slit={:d}'.format(slit_spat))
                plt.show()

            # ----------------------------------------------------------
            # Construct the illumination profile with the tweaked edges
            # of the slit
            if exit_status <= 1:
                # TODO -- JFH -- Check this is ok for flexure!!
                self.msillumflat[onslit_tweak] = spat_bspl.value(spat_coo_final[onslit_tweak])[0]
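                # (bspline.value() returns a (values, mask) tuple; only the
                # evaluated profile is stored in the illumination flat.)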
                self.list_of_spat_bsplines.append(spat_bspl)
            else:
                # Nothing to save; store an empty bspline as a placeholder
                msgs.warn('Slit illumination profile bspline fit failed!  Spatial profile not '
                          'included in flat-field model for slit {0}!'.format(slit_spat))
                self.slits.mask[slit_idx] = self.slits.bitmask.turn_on(self.slits.mask[slit_idx], 'BADFLATCALIB')
                self.list_of_spat_bsplines.append(bspline.bspline(None))
                continue

            # ----------------------------------------------------------
            # Fit the 2D residuals of the 1D spectral and spatial fits.
            msgs.info('Performing 2D illumination + scattered light flat field fit')

            # Construct the spectrally and spatially normalized flat
            norm_spec_spat[...] = 1.
            norm_spec_spat[onslit_tweak] = rawflat[onslit_tweak] / np.fmax(spec_model[onslit_tweak], 1.0) \
                                                    / np.fmax(self.msillumflat[onslit_tweak], 0.01)
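            # (The floors of 1 count on the spectral model and 0.01 on the
            # illumination profile keep the ratio finite where either model
            # approaches zero.)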

            # Sort the pixels by their spectral coordinate. The mask
            # uses the nominal padding defined by the slits object.
            twod_gpm, twod_srt, twod_spec_coo_data, twod_flat_data \
                    = flat.sorted_flat_data(norm_spec_spat, spec_coo, gpm=onslit_tweak)
            # Also apply the sorting to the spatial coordinates
            twod_spat_coo_data = spat_coo_final[twod_gpm].ravel()[twod_srt]
            # TODO: Reset back to the original gpm if sticky is true?
            twod_gpm_data = gpm[twod_gpm].ravel()[twod_srt]
            # Only fit data with less than 30% variations
            # TODO: Make 30% a parameter?
            twod_gpm_data &= np.absolute(twod_flat_data - 1) < 0.3
            # Here we ignore the formal photon-counting errors and simply
            # assume a typical error per pixel. This guess is somewhat
            # arbitrary. We then set the rejection threshold with
            # twod_sigrej.
            # TODO: Make twod_sig and twod_sigrej parameters?
            twod_sig = 0.01
            twod_ivar_data = twod_gpm_data.astype(float)/(twod_sig**2)
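            # Multiplying by the boolean mask gives masked pixels zero inverse
            # variance, i.e. zero weight in the 2D fit.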
            twod_sigrej = 4.0

            poly_basis = basis.fpoly(2.0*twod_spat_coo_data - 1.0, npoly)
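            # The normalized slit position on [0, 1] is rescaled to [-1, 1]
            # before constructing the polynomial basis; this basis lets the
            # bspline coefficients of the 2D fit vary with spatial position.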

            # Perform the full 2d fit
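            # utils.bspline_profile fits the sorted flat data with a bspline
            # in the spectral coordinate whose amplitudes are modulated by the
            # spatial polynomial basis, iteratively rejecting points more than
            # twod_sigrej sigma from the fit.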
            twod_bspl, twod_gpm_fit, twod_flat_fit, _ , exit_status \
                    = utils.bspline_profile(twod_spec_coo_data, twod_flat_data, twod_ivar_data,
                                            poly_basis, ingpm=twod_gpm_data, nord=4,
                                            upper=twod_sigrej, lower=twod_sigrej,
                                            kwargs_bspline={'bkspace': spec_samp_coarse},
                                            kwargs_reject={'groupbadpix': True, 'maxrej': 10})
            if debug:
                # TODO: Make a plot that shows the residuals in the 2D
                # image
                resid = twod_flat_data - twod_flat_fit
                goodpix = twod_gpm_fit & twod_gpm_data
                badpix = np.invert(twod_gpm_fit) & twod_gpm_data

                plt.clf()
                ax = plt.gca()
                ax.plot(twod_spec_coo_data[goodpix], resid[goodpix], color='k', marker='o',
                        markersize=0.2, mfc='k', fillstyle='full', linestyle='None',
                        label='good points')
                ax.plot(twod_spec_coo_data[badpix], resid[badpix], color='red', marker='+',
                        markersize=0.5, mfc='red', fillstyle='full', linestyle='None',
                        label='masked')
                ax.axhline(twod_sigrej*twod_sig, color='lawngreen', linestyle='--',
                           label='rejection thresholds', zorder=10, linewidth=2.0)
                ax.axhline(-twod_sigrej*twod_sig, color='lawngreen', linestyle='--', zorder=10,
                           linewidth=2.0)
#                ax.set_ylim(-0.05, 0.05)
                ax.legend()
                ax.set_xlabel('Spectral Pixel')
                ax.set_ylabel('Residuals from pixelflat 2-d fit')
                ax.set_title('Spectral Residuals for slit={:d}'.format(slit_spat))
                plt.show()

                plt.clf()
                ax = plt.gca()
                ax.plot(twod_spat_coo_data[goodpix], resid[goodpix], color='k', marker='o',
                        markersize=0.2, mfc='k', fillstyle='full', linestyle='None',
                        label='good points')
                ax.plot(twod_spat_coo_data[badpix], resid[badpix], color='red', marker='+',
                        markersize=0.5, mfc='red', fillstyle='full', linestyle='None',
                        label='masked')
                ax.axhline(twod_sigrej*twod_sig, color='lawngreen', linestyle='--',
                           label='rejection thresholds', zorder=10, linewidth=2.0)
                ax.axhline(-twod_sigrej*twod_sig, color='lawngreen', linestyle='--', zorder=10,
                           linewidth=2.0)
#                ax.set_ylim((-0.05, 0.05))
#                ax.set_xlim(-0.02, 1.02)
                ax.legend()
                ax.set_xlabel('Normalized Slit Position')
                ax.set_ylabel('Residuals from pixelflat 2-d fit')
                ax.set_title('Spatial Residuals for slit={:d}'.format(slit_spat))
                plt.show()

            # Save the 2D residual model
            twod_model[...] = 1.
            if exit_status > 1:
                msgs.warn('Two-dimensional fit to flat-field data failed!  No higher order '
                          'flat-field corrections included in model of slit {0}!'.format(slit_spat))
            else:
                twod_model[twod_gpm] = twod_flat_fit[np.argsort(twod_srt)]
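                # np.argsort(twod_srt) is the inverse of the sorting
                # permutation, so the fitted values are mapped back to the
                # original (unsorted) pixel order before filling the image.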

            # Construct the full flat-field model
            # TODO: Why is the 0.05 here for the illumflat compared to the 0.01 above?
            self.flat_model[onslit_tweak] = twod_model[onslit_tweak] \
                                        * np.fmax(self.msillumflat[onslit_tweak], 0.05) \
                                        * np.fmax(spec_model[onslit_tweak], 1.0)

            # Construct the pixel flat
            #self.mspixelflat[onslit] = rawflat[onslit]/self.flat_model[onslit]
            #self.mspixelflat[onslit_tweak] = 1.
            #trimmed_slitid_img_anew = self.slits.slit_img(pad=-trim, slitidx=slit_idx)
            #onslit_trimmed_anew = trimmed_slitid_img_anew == slit_spat
            self.mspixelflat[onslit_tweak] = rawflat[onslit_tweak]/self.flat_model[onslit_tweak]
            # TODO: Add some code here to treat the edges and places where fits
            #  go bad?

        # Set the pixelflat to 1.0 wherever the flat was nonlinear
        self.mspixelflat[rawflat >= nonlinear_counts] = 1.0
        # Set the pixelflat to 1.0 within trim pixels of all the slit edges
        trimmed_slitid_img_new = self.slits.slit_img(pad=-trim, initial=False)
        tweaked_slitid_img = self.slits.slit_img(initial=False)
        self.mspixelflat[(trimmed_slitid_img_new < 0) & (tweaked_slitid_img > 0)] = 1.0
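        # (Pixels inside the tweaked slits but outside the trimmed slit image
        # are exactly the 'trim' pixels closest to each slit edge.)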


        # Do not apply pixel flat-field corrections that are greater than
        # 100% to avoid creating edge effects, etc.
        self.mspixelflat = np.clip(self.mspixelflat, 0.5, 2.0)
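        # (The clip to [0.5, 2.0] enforces the factor-of-two, i.e. 100%,
        # limit in both directions.)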