robust_mask_nosticky = msk_nosticky == 1
maskdiff = np.sum(robust_mask_nosticky != robust_mask_new)
coeffdiff = np.any(poly_coeff_nosticky != poly_coeff_new)

print('Mask differences: {:d}; coefficients differ: {}'.format(maskdiff, coeffdiff))

plt.figure()
plt.plot(xfit,yfit,'ko',mfc='None',label='Good Points')
#plt.plot(xfit[[0,10,19,49,69]],yfit[[0,10,19,49,69]],'bo',label='Outlier')
plt.plot(xfit[indrand],yfit[indrand],'bo',label='Outlier')
plt.plot(xfit,yreal,'k-',lw=3,label='Real')
plt.plot(xfit[~robust_mask], yfit[~robust_mask], 'ms', markersize=10.0, mfc='None', label='robust_polyfit rejected')
plt.plot(xfit[~robust_mask_new], yfit[~robust_mask_new], 'r+', markersize=20.0, label='robust_polyfit_djs rejected')
plt.plot(xfit[~robust_mask_nosticky], yfit[~robust_mask_nosticky], 'go', mfc='None', markersize=30.0, label='robust_polyfit_djs (non-sticky) rejected')

plt.plot(xvec, utils.func_val(poly_coeff, xvec, 'polynomial'), lw=2, ls='-.', color='m', label='robust polyfit')
plt.plot(xvec, utils.func_val(poly_coeff_new, xvec, 'polynomial'), lw=1, ls='--', color='r', label='new robust polyfit')
plt.plot(xvec, utils.func_val(poly_coeff_nosticky, xvec, 'polynomial'), lw=1, ls=':', color='g', label='new robust polyfit (non-sticky)')

plt.legend()
plt.show()

sys.exit(-1)

### Test some results from the PCA
#xfit = np.array([0.,1.,2.,3.,4.,5.,6.,7.,8.,9.])
#yfit = np.array([ 5205.0605,3524.0981,1974.9368,694.22455,-359.67508,-1217.6045,-1898.94,-2371.3154,-2726.7856,-2823.9968 ])
xfit = np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
yfit = np.array([  5.781704 ,  -4.7644916,  -2.8626044,  -2.2049518,  -1.45643  ,-2.9384856,   4.4096513,  22.384567 ,  -6.0114756, -12.337624 ])
norder = 3
xvec = np.linspace(xfit.min(), xfit.max(), num=100)
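
# A minimal sketch added for illustration (not from the original test script): fit this
# PCA test data with the same robust_polyfit_djs call exercised in pca_trace() below,
# flag the rejected points, and overplot the fit. Variable names are hypothetical.
msk_test, coeff_test = utils.robust_polyfit_djs(xfit, yfit, norder, function='polynomial',
                                                lower=5, upper=5, maxiter=10, sticky=False)
plt.plot(xfit, yfit, 'ko', mfc='None', label='PCA test data')
plt.plot(xfit[msk_test != 1], yfit[msk_test != 1], 'r+', markersize=20.0, label='rejected')
plt.plot(xvec, utils.func_val(coeff_test, xvec, 'polynomial'), 'r--', label='robust_polyfit_djs fit')
plt.legend()
plt.show()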
def reidentify_old(spec, wv_calib_arxiv, lamps, nreid_min, detections=None, cc_thresh=0.8,cc_local_thresh = 0.8,
               line_pix_tol=2.0, nlocal_cc=11, rms_threshold=0.15, nonlinear_counts=1e10,sigdetect = 5.0,
               use_unknowns=True,match_toler=3.0,func='legendre',n_first=2,sigrej_first=3.0,n_final=4, sigrej_final=2.0,
               seed=None, debug_xcorr=False, debug_reid=False):

    """ Determine  a wavelength solution for a set of spectra based on archival wavelength solutions

    Parameters
    ----------
    spec :  float ndarray (nspec, nslits)
       Array of arc spectra for which wavelength solutions are desired.

    wv_calib_arxiv: dict
       Dictionary containing archival wavelength solutions for a collection of slits/orders to be used to reidentify
       lines and  determine the wavelength solution for spec. This dict is a standard format for PypeIt wavelength solutions
       as created by pypeit.core.wavecal.fitting.iterative_fitting

    lamps: list of strings
       The arc line lamps that are on, or the name of the line list that should be used. For example, for Shane Kast blue
       this would be ['CdI','HgI','HeI']. For X-shooter NIR, which calibrates off a custom OH sky line list,
       it is the name of the line list, i.e. ['OH_XSHOOTER']


    Optional Parameters
    -------------------
    detections: float ndarray, default = None
       An array containing the pixel centroids of the lines in the arc as computed by the pypeit.core.arc.detect_lines
       code. If this is set to None, the line detection will be run inside the code.

    cc_thresh: float, default = 0.8
       Threshold for the *global* cross-correlation coefficient between an input spectrum and a member of the archive required to
       attempt reidentification. Spectra from the archive with a lower cross-correlation are not used for reidentification

    cc_local_thresh: float, default = 0.8
       Threshold for the *local* cross-correlation coefficient, evaluated at each reidentified line, between an input
       spectrum and the shifted and stretched archive spectrum, above which a line is considered a good match for
       reidentification. The local cross-correlation is evaluated at each candidate reidentified line
       (using a window of nlocal_cc), and is then used to score the reidentified lines to arrive at the final set of
       good reidentifications

    line_pix_tol: float, default = 2.0
       Matching tolerance in pixels for a line reidentification. A good line match must match within this tolerance to
       the shifted and stretched archive spectrum, and the archive wavelength solution at this match must be within
       line_pix_tol dispersion elements of the line in the line list.

    nlocal_cc: int, default = 11
       Size of pixel window used for local cross-correlation computation for each arc line. If not an odd number, one will
       be added to it to make it odd.

    rms_threshold: float, default = 0.15
       Maximum rms for considering a wavelength solution to be an acceptable fit. Slits/orders with a larger RMS
       than this are flagged as bad slits

    nonlinear_counts: float, default = 1e10
       Arc lines above this saturation threshold are not used in wavelength solution fits because they cannot be accurately
       centroided

    sigdetect: float, default = 5.0
       Sigma threshold above fluctuations for arc-line detection. Arcs are continuum subtracted and the fluctuations are
       computed after continuum subtraction.

    use_unknowns : bool, default = True
       If True, arc lines that are known to be present in the spectra, but have not been attributed to an element+ion,
       will be included in the fit.

    match_toler: float, default = 3.0
       Matching tolerance when searching for new lines. This is the difference in pixels between the wavelength assigned to
       an arc line by an iteration of the wavelength solution and the wavelength in the line list.

    func: str, default = 'legendre'
       Name of function used for the wavelength solution

    n_first: int, default = 2
       Order of first guess to the wavelength solution.

    sigrej_first: float, default = 3.0
       Number of sigma for rejection for the first guess to the wavelength solution.

    n_final: int, default = 4
       Order of the final wavelength solution fit

    sigrej_final: float, default = 2.0
       Number of sigma for rejection for the final fit to the wavelength solution.

    seed: int or np.random.RandomState, optional, default = None
       Seed for scipy.optimize.differential_evolution optimizer. If not specified, the calculation will be seeded
       in a deterministic way from the input arc spectrum spec.

    debug_xcorr: bool, default = False
       Show plots useful for debugging the cross-correlation used for shift/stretch computation

    debug_reid: bool, default = False
       Show plots useful for debugging the line reidentification

    Returns
    -------
    (wv_calib, patt_dict, bad_slits)

    wv_calib: dict
       Wavelength solutions for the input arc spectra spec, stored in standard PypeIt format: each slit index of
       spec[:, slit] corresponds to a key str(slit) in wv_calib, which yields the final_fit dictionary for that slit

    patt_dict: dict
       Arc lines pattern dictionary with some information about the IDs as well as the cross-correlation values

    bad_slits: ndarray, int
       Numpy array with the indices of the bad slits. These are the indices in the input arc spectrum array spec[:,islit]


    Revision History
    ----------------
    November 2018 by J.F. Hennawi. Based on an initial version of this code written by Ryan Cooke.
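
    Examples
    --------
    A hedged usage sketch (not part of the original docstring); ``arc_spec`` and a
    populated ``wv_calib_arxiv`` dict are assumed to exist already::

        wv_calib, patt_dict, bad_slits = reidentify_old(arc_spec, wv_calib_arxiv,
                                                        ['CdI', 'HgI', 'HeI'], nreid_min=4)
        print('Bad slits: {:s}'.format(str(bad_slits)))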
    """

    # Determine the seed for scipy.optimize.differential_evolution optimizer
    if seed is None:
        # If no seed is specified, just take the sum of all the elements and round that to an integer
        seed = min(int(np.sum(spec)), 2**32 - 1)

    random_state = np.random.RandomState(seed = seed)


    nlocal_cc_odd = nlocal_cc + 1 if nlocal_cc % 2 == 0 else nlocal_cc
    window = 1.0/nlocal_cc_odd * np.ones(nlocal_cc_odd)
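    # 'window' is a normalized boxcar kernel of odd length; convolving with it below
    # produces the running zero-lag averages used for the local cross-correlation.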

    # Generate the line list
    line_lists = waveio.load_line_lists(lamps)
    unknwns = waveio.load_unknown_list(lamps)
    if use_unknowns:
        tot_list = table.vstack([line_lists, unknwns])
    else:
        tot_list = line_lists
    # Generate the final linelist and sort
    wvdata = np.array(tot_list['wave'].data)  # Removes mask if any
    wvdata.sort()

    nspec, nslits = spec.shape
    narxiv = len(wv_calib_arxiv)
    nspec_arxiv = wv_calib_arxiv['0']['spec'].size
    if nspec_arxiv != nspec:
        msgs.error('Different spectral binning is not supported yet but it will be soon')

    # If the detections were not passed in find the lines in each spectrum
    if detections is None:
        detections = {}
        for islit in range(nslits):
            tcent, ecent, cut_tcent, icut = wvutils.arc_lines_from_spec(spec[:, islit], sigdetect=sigdetect,nonlinear_counts=nonlinear_counts)
            detections[str(islit)] = [tcent[icut].copy(), ecent[icut].copy()]
    else:
        if len(detections) != nslits:
            msgs.error('Detections must be a dictionary with nslit elements')

    # For convenience pull out all the spectra from the wv_calib_arxiv archive
    spec_arxiv = np.zeros((nspec, narxiv))
    wave_soln_arxiv = np.zeros((nspec, narxiv))
    wvc_arxiv = np.zeros(narxiv, dtype=float)
    disp_arxiv = np.zeros(narxiv, dtype=float)
    xrng = np.arange(nspec_arxiv)
    for iarxiv in range(narxiv):
        spec_arxiv[:,iarxiv] = wv_calib_arxiv[str(iarxiv)]['spec']
        fitc = wv_calib_arxiv[str(iarxiv)]['fitc']
        fitfunc = wv_calib_arxiv[str(iarxiv)]['function']
        fmin, fmax = wv_calib_arxiv[str(iarxiv)]['fmin'],wv_calib_arxiv[str(iarxiv)]['fmax']
        wave_soln_arxiv[:,iarxiv] = utils.func_val(fitc, xrng, fitfunc, minv=fmin, maxv=fmax)
        wvc_arxiv[iarxiv] = wave_soln_arxiv[nspec_arxiv//2, iarxiv]
        disp_arxiv[iarxiv] = np.median(wave_soln_arxiv[:,iarxiv] - np.roll(wave_soln_arxiv[:,iarxiv], 1))
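        # (np.roll wraps the first element around, so the first difference is spurious;
        # the median keeps this dispersion estimate robust to that single bad value)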

    wv_calib = {}
    patt_dict = {}
    bad_slits = np.array([], dtype=int)

    marker_tuple = ('o','v','<','>','8','s','p','P','*','X','D','d','x')
    color_tuple = ('black','green','red','cyan','magenta','blue','darkorange','yellow','dodgerblue','purple','lightgreen','cornflowerblue')
    marker = itertools.cycle(marker_tuple)
    colors = itertools.cycle(color_tuple)

    # Loop over the slits in the spectrum and cross-correlate each with each arxiv spectrum to identify lines
    for islit in range(nslits):
        slit_det = detections[str(islit)][0]
        line_indx = np.array([], dtype=int)
        det_indx = np.array([], dtype=int)
        line_cc = np.array([], dtype=float)
        line_iarxiv = np.array([], dtype=int)
        wcen = np.zeros(narxiv)
        disp = np.zeros(narxiv)
        shift_vec = np.zeros(narxiv)
        stretch_vec = np.zeros(narxiv)
        ccorr_vec = np.zeros(narxiv)
        for iarxiv in range(narxiv):
            msgs.info('Cross-correlating slit # {:d}'.format(islit + 1) + ' with arxiv slit # {:d}'.format(iarxiv + 1))
            # Match the peaks between the two spectra. This code attempts to compute the stretch if cc > cc_thresh
            success, shift_vec[iarxiv], stretch_vec[iarxiv], ccorr_vec[iarxiv], _, _ = \
                wvutils.xcorr_shift_stretch(spec[:, islit], spec_arxiv[:, iarxiv], cc_thresh=cc_thresh, seed = random_state,
                                            debug=debug_xcorr)
            # If cc < cc_thresh or if this optimization failed, don't reidentify from this arxiv spectrum
            if success != 1:
                continue
            # Estimate wcen and disp for this slit based on its shift/stretch relative to the archive slit
            disp[iarxiv] = disp_arxiv[iarxiv] / stretch_vec[iarxiv]
            wcen[iarxiv] = wvc_arxiv[iarxiv] - shift_vec[iarxiv]*disp[iarxiv]
            # For each peak in the arxiv spectrum, identify the corresponding peaks in the input islit spectrum. Do this by
            # transforming these arxiv slit line pixel locations into the (shifted and stretched) input islit spectrum frame
            arxiv_det = wv_calib_arxiv[str(iarxiv)]['xfit']
            arxiv_det_ss = arxiv_det*stretch_vec[iarxiv] + shift_vec[iarxiv]
            spec_arxiv_ss = wvutils.shift_and_stretch(spec_arxiv[:, iarxiv], shift_vec[iarxiv], stretch_vec[iarxiv])

            if debug_xcorr:
                plt.figure(figsize=(14, 6))
                tampl_slit = np.interp(slit_det, xrng, spec[:, islit])
                plt.plot(xrng, spec[:, islit], color='red', drawstyle='steps-mid', label='input arc',linewidth=1.0, zorder=10)
                plt.plot(slit_det, tampl_slit, 'r.', markersize=10.0, label='input arc lines', zorder=10)
                tampl_arxiv = np.interp(arxiv_det, xrng, spec_arxiv[:, iarxiv])
                plt.plot(xrng, spec_arxiv[:, iarxiv], color='black', drawstyle='steps-mid', linestyle=':',
                         label='arxiv arc', linewidth=0.5)
                plt.plot(arxiv_det, tampl_arxiv, 'k+', markersize=8.0, label='arxiv arc lines')
                # tampl_ss = np.interp(gsdet_ss, xrng, gdarc_ss)
                for iline in range(arxiv_det_ss.size):
                    plt.plot([arxiv_det[iline], arxiv_det_ss[iline]], [tampl_arxiv[iline], tampl_arxiv[iline]],
                             color='cornflowerblue', linewidth=1.0)
                plt.plot(xrng, spec_arxiv_ss, color='black', drawstyle='steps-mid', label='arxiv arc shift/stretch',linewidth=1.0)
                plt.plot(arxiv_det_ss, tampl_arxiv, 'k.', markersize=10.0, label='predicted arxiv arc lines')
                plt.title(
                    'Cross-correlation of input slit # {:d}'.format(islit + 1) + ' and arxiv slit # {:d}'.format(iarxiv + 1) +
                    ': ccor = {:5.3f}'.format(ccorr_vec[iarxiv]) +
                    ', shift = {:6.1f}'.format(shift_vec[iarxiv]) +
                    ', stretch = {:5.4f}'.format(stretch_vec[iarxiv]) +
                    ', wv_cen = {:7.1f}'.format(wcen[iarxiv]) +
                    ', disp = {:5.3f}'.format(disp[iarxiv]))
                plt.ylim(1.2*spec[:, islit].min(), 1.5 *spec[:, islit].max())
                plt.legend()
                plt.show()

            # Calculate wavelengths for all of the arxiv_det detections
            wvval_arxiv= utils.func_val(wv_calib_arxiv[str(iarxiv)]['fitc'], arxiv_det,wv_calib_arxiv[str(iarxiv)]['function'],
                                        minv=wv_calib_arxiv[str(iarxiv)]['fmin'], maxv=wv_calib_arxiv[str(iarxiv)]['fmax'])
            # Compute a "local" zero lag correlation of the slit spectrum and the shifted and stretch arxiv spectrum over a
            # a nlocal_cc_odd long segment of spectrum. We will then uses spectral similarity as a further criteria to
            # decide which lines are good matches
            prod_smooth = scipy.ndimage.convolve1d(spec[:, islit]*spec_arxiv_ss, window)
            spec2_smooth = scipy.ndimage.convolve1d(spec[:, islit]**2, window)
            arxiv2_smooth = scipy.ndimage.convolve1d(spec_arxiv_ss**2, window)
            denom = np.sqrt(spec2_smooth*arxiv2_smooth)
            corr_local = np.zeros_like(denom)
            corr_local[denom > 0] = prod_smooth[denom > 0]/denom[denom > 0]
            corr_local[denom == 0.0] = -1.0

            # Loop over the current slit line pixel detections and find the nearest arxiv spectrum line
            for iline in range(slit_det.size):
                # match to pixel in shifted/stretch arxiv spectrum
                pdiff = np.abs(slit_det[iline] - arxiv_det_ss)
                bstpx = np.argmin(pdiff)
                # If a match is found within line_pix_tol pixels, consider this a successful match
                if pdiff[bstpx] < line_pix_tol:
                    # Using the arxiv arc wavelength solution, search for the nearest line in the line list
                    bstwv = np.abs(wvdata - wvval_arxiv[bstpx])
                    # This is a good wavelength match if it is within line_pix_tol dispersion elements
                    if bstwv[np.argmin(bstwv)] < line_pix_tol*disp_arxiv[iarxiv]:
                        line_indx = np.append(line_indx, np.argmin(bstwv))  # index in the line list array wvdata of this match
                        det_indx = np.append(det_indx, iline)             # index of this line in the detected line array slit_det
                        line_cc = np.append(line_cc,np.interp(slit_det[iline],xrng,corr_local)) # local cross-correlation at this match
                        line_iarxiv = np.append(line_iarxiv,iarxiv)

        narxiv_used = np.sum(wcen != 0.0)
        if (narxiv_used == 0) or (len(np.unique(line_indx)) < 3):
            wv_calib[str(islit)] = {}
            patt_dict[str(islit)] = {}
            bad_slits = np.append(bad_slits,islit)
            continue

        if debug_reid:
            plt.figure(figsize=(14, 6))
            # Plot a summary of the local x-correlation values for each line on each slit
            for iarxiv in range(narxiv):
                # Only plot those that we actually tried to reidentify (i.e. above cc_thresh)
                if wcen[iarxiv] != 0.0:
                    this_iarxiv = line_iarxiv == iarxiv
                    plt.plot(wvdata[line_indx[this_iarxiv]],line_cc[this_iarxiv],marker=next(marker),color=next(colors),
                             linestyle='',markersize=5.0,label='arxiv slit={:d}'.format(iarxiv))

            plt.hlines(cc_local_thresh, wvdata[line_indx].min(), wvdata[line_indx].max(), color='red', linestyle='--', label='Local xcorr threshold')
            plt.title('slit={:d}'.format(islit + 1) + ': Local x-correlation for reidentified lines from narxiv_used={:d}'.format(narxiv_used) +
                      ' arxiv slits. Requirement: nreid_min={:d}'.format(nreid_min) + ' matches > threshold')
            plt.xlabel('wavelength from line list')
            plt.ylabel('Local x-correlation coefficient')
            #plt.ylim((0.0, 1.2))
            plt.legend()
            plt.show()

        # Finalize the best guess of each line
        # Initialise the patterns dictionary, min_nsig not used anywhere
        patt_dict_slit = dict(acceptable=False, nmatch=0, ibest=-1, bwv=0., min_nsig=sigdetect, mask=np.zeros(slit_det.size, dtype=bool))
        patt_dict_slit['sign'] = 1 # This is not used anywhere
        patt_dict_slit['bwv'] = np.median(wcen[wcen != 0.0])
        patt_dict_slit['bdisp'] = np.median(disp[disp != 0.0])
        patterns.solve_xcorr(slit_det, wvdata, det_indx, line_indx, line_cc, patt_dict=patt_dict_slit,nreid_min=nreid_min,
                             cc_local_thresh=cc_local_thresh)

        if debug_reid:
            tmp_list = table.vstack([line_lists, unknwns])
            qa.match_qa(spec[:, islit], slit_det, tmp_list, patt_dict_slit['IDs'], patt_dict_slit['scores'])

        # Use only the perfect IDs
        imperfect = np.array(patt_dict_slit['scores']) != 'Perfect'
        patt_dict_slit['mask'][imperfect] = False
        patt_dict_slit['nmatch'] = np.sum(patt_dict_slit['mask'])
        if patt_dict_slit['nmatch'] < 3:
            patt_dict_slit['acceptable'] = False

        # Check if an acceptable reidentification solution was found
        if not patt_dict_slit['acceptable']:
            wv_calib[str(islit)] = {}
            patt_dict[str(islit)] = copy.deepcopy(patt_dict_slit)
            bad_slits = np.append(bad_slits,islit)
            continue
        # Perform the fit
        final_fit = fitting.fit_slit(spec[:,islit], patt_dict_slit, slit_det, line_lists, match_toler=match_toler,
                             func=func, n_first=n_first,sigrej_first=sigrej_first,n_final=n_final,
                             sigrej_final=sigrej_final)

        # Did the fit succeed?
        if final_fit is None:
            # This pattern wasn't good enough
            wv_calib[str(islit)] = {}
            patt_dict[str(islit)] = copy.deepcopy(patt_dict_slit)
            bad_slits = np.append(bad_slits, islit)
            continue
        # Is the RMS below the threshold?
        if final_fit['rms'] > rms_threshold:
            msgs.warn('---------------------------------------------------' + msgs.newline() +
                      'Reidentify report for slit {0:d}/{1:d}:'.format(islit + 1, nslits) + msgs.newline() +
                      '  Poor RMS ({0:.3f})! Need to add additional spectra to arxiv to improve fits'.format(final_fit['rms']) + msgs.newline() +
                      '---------------------------------------------------')
            bad_slits = np.append(bad_slits, islit)
            # Note this result in new_bad_slits, but store the solution since this might be the best possible

        # Add the patt_dict and wv_calib to the output dicts
        patt_dict[str(islit)] = copy.deepcopy(patt_dict_slit)
        wv_calib[str(islit)] = copy.deepcopy(final_fit)
        if debug_reid:
            qa.arc_fit_qa(wv_calib[str(islit)])
            #yplt = utils.func_val(final_fit['fitc'], xrng, final_fit['function'], minv=final_fit['fmin'], maxv=final_fit['fmax'])
            #plt.plot(final_fit['xfit'], final_fit['yfit'], 'bx')
            #plt.plot(xrng, yplt, 'r-')
            #plt.show()

    return wv_calib, patt_dict, bad_slits
def pca_trace(xcen, usepca = None, npca = 2, npoly_cen = 3, debug=True):

    nspec = xcen.shape[0]
    norders = xcen.shape[1]
    if usepca is None:
        usepca = np.zeros(norders,dtype=bool)

    # Orders with use_order = True are used to predict the bad (usepca = True) orders
    use_order = np.invert(usepca)
    ngood = np.sum(use_order)
    if ngood < npca:
        msgs.warn('Not enough good traces for a PCA fit: ngood = {:d}'.format(ngood) + ' is < npca = {:d}'.format(npca))
        msgs.warn('Using the input trace for now')
        return xcen

    pca = PCA(n_components=npca)
    xcen_use = (xcen[:,use_order] - np.mean(xcen[:,use_order],0)).T
    pca_coeffs_use = pca.fit_transform(xcen_use)
    pca_vectors = pca.components_

    # Fit first pca dimension (with largest variance) with a higher order npoly depending on number of good orders.
    # Fit all higher dimensions (with lower variance) with a line
    npoly = int(np.fmin(np.fmax(np.floor(3.3*ngood/norders),1.0),3.0))
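    # Worked example of the line above: ngood = 8 good orders out of norders = 10 gives
    # floor(3.3*8/10) = floor(2.64) = 2, which is then clipped to the range [1, 3].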
    npoly_vec = np.full(npca, npoly)
    order_vec = np.arange(norders,dtype=float)
    # pca_coeffs = np.zeros((norders, npca))
    pca_coeffs_new = np.zeros((norders, npca))
    # Now loop over the dimensionality of the compression and perform a polynomial fit to
    for idim in range(npca):
        # ToDO robust_polyfit is garbage remove it entirely from PypeIT!
        xfit = order_vec[use_order]
        yfit = pca_coeffs_use[:,idim]
        norder = npoly_vec[idim]

        # msk, poly_coeff = utils.robust_polyfit(xfit, yfit, norder, sigma = 3.0, function='polynomial')
        # pca_coeffs[:,idim] = utils.func_val(poly_coeff, order_vec, 'polynomial')

        # TESTING traceset fitting
        xtemp = xfit.reshape(1, xfit.size)
        ytemp = yfit.reshape(1, yfit.size)
        tset = pydl.xy2traceset(xtemp, ytemp, ncoeff=norder,func='polynomial')
        #tset_yfit = tset.yfit.reshape(tset.yfit.shape[1])

        ## Test new robust fitting with djs_reject
        msk_new, poly_coeff_new = utils.robust_polyfit_djs(xfit, yfit, norder,
                                                           function='polynomial', minv=None, maxv=None, bspline_par=None,
                                                           guesses=None, maxiter=10, inmask=None, sigma=None, invvar=None,
                                                           lower=5, upper=5, maxdev=None, maxrej=None, groupdim=None,
                                                           groupsize=None, groupbadpix=False, grow=0, sticky=False)
        pca_coeffs_new[:,idim] = utils.func_val(poly_coeff_new, order_vec, 'polynomial')

        if debug:
            # Evaluate the fit
            xvec = np.linspace(order_vec.min(),order_vec.max(),num=100)
            (_,tset_fit) = tset.xy(xpos=xvec.reshape(1,xvec.size))
            yfit_tset = tset_fit[0,:]
            #robust_mask = msk == 0
            robust_mask_new = msk_new == 1
            tset_mask = tset.outmask[0,:]
            plt.plot(xfit, yfit, 'ko', mfc='None', markersize=8.0, label='pca coeff')
            #plt.plot(xfit[~robust_mask], yfit[~robust_mask], 'ms', mfc='None', markersize=10.0,label='robust_polyfit rejected')
            #plt.plot(xfit[~robust_mask_new], yfit[~robust_mask_new], 'r+', markersize=20.0,label='robust_polyfit_djs rejected')
            plt.plot(xfit[~tset_mask],yfit[~tset_mask], 'bo', markersize = 10.0, label = 'traceset rejected')
            #plt.plot(xvec, utils.func_val(poly_coeff, xvec, 'polynomial'),ls='--', color='m', label='robust polyfit')
            plt.plot(xvec, utils.func_val(poly_coeff_new, xvec, 'polynomial'),ls='-.', color='r', label='new robust polyfit')
            plt.plot(xvec, yfit_tset,ls=':', color='b',label='traceset')
            plt.legend()
            plt.show()

    #ToDo should we be masking the bad orders here and interpolating/extrapolating?
    spat_mean = np.mean(xcen,0)
    msk_spat, poly_coeff_spat = utils.robust_polyfit(order_vec, spat_mean, npoly_cen, sigma = 3.0, function = 'polynomial')
    ibad = np.where(msk_spat == 1)
    spat_mean[ibad] = utils.func_val(poly_coeff_spat,order_vec[ibad],'polynomial')

    #pca_fit = np.outer(np.ones(nspec), spat_mean) + np.outer(pca.mean_,np.ones(norders)) + (np.dot(pca_coeffs, pca_vectors)).T
    pca_fit = np.outer(np.ones(nspec), spat_mean) + np.outer(pca.mean_,np.ones(norders)) + (np.dot(pca_coeffs_new, pca_vectors)).T

    return pca_fit
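
# A self-contained sketch (synthetic data, not from PypeIt) of the compress ->
# smooth-the-coefficients -> reconstruct pattern that pca_trace() implements above.
import numpy as np
from sklearn.decomposition import PCA

nspec, norders, npca = 100, 10, 2
spec_vec = np.linspace(-1., 1., nspec)
order_vec = np.arange(norders, dtype=float)
# Synthetic traces: a smooth order-dependent curvature plus a little noise
xcen = 50. + 40.*np.outer(spec_vec**2, order_vec/norders) + 0.1*np.random.randn(nspec, norders)

pca = PCA(n_components=npca)
pca_coeffs = pca.fit_transform((xcen - np.mean(xcen, 0)).T)      # shape (norders, npca)
# Fit each PCA coefficient as a smooth (here quadratic) function of order number
pca_coeffs_fit = np.zeros_like(pca_coeffs)
for idim in range(npca):
    cfit = np.polynomial.polynomial.polyfit(order_vec, pca_coeffs[:, idim], 2)
    pca_coeffs_fit[:, idim] = np.polynomial.polynomial.polyval(order_vec, cfit)
# Reconstruct: per-order mean + PCA mean + smoothed coefficients projected on components
recon = np.outer(np.ones(nspec), np.mean(xcen, 0)) \
        + np.outer(pca.mean_, np.ones(norders)) \
        + np.dot(pca_coeffs_fit, pca.components_).T
print('max |recon - xcen| = {:.3f}'.format(np.max(np.abs(recon - xcen))))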
Example #4
def spec_flex_shift(obj_skyspec, arx_skyspec, mxshft=20):
    """ Calculate shift between object sky spectrum and archive sky spectrum

    Args:
        obj_skyspec (:class:`linetools.spectra.xspectrum1d.XSpectrum1d`):
            Spectrum of the sky related to our object
        arx_skyspec (:class:`linetools.spectra.xspectrum1d.XSpectrum1d`):
            Archived sky spectrum
        mxshft (float, optional):
            Maximum allowed shift from flexure; note there are cases that
            have been known to exceed even 30 pixels.

    Returns:
        dict: Contains flexure info
    """

    # TODO None of these routines should have dependencies on XSpectrum1d!

    # Determine the brightest emission lines
    msgs.warn("If we use Paranal, cut down on wavelength early on")
    arx_amp, arx_amp_cont, arx_cent, arx_wid, _, arx_w, arx_yprep, nsig \
            = arc.detect_lines(arx_skyspec.flux.value)
    obj_amp, obj_amp_cont, obj_cent, obj_wid, _, obj_w, obj_yprep, nsig_obj \
            = arc.detect_lines(obj_skyspec.flux.value)

    # Keep only 5 brightest amplitude lines (xxx_keep is array of
    # indices within arx_w of the 5 brightest)
    arx_keep = np.argsort(arx_amp[arx_w])[-5:]
    obj_keep = np.argsort(obj_amp[obj_w])[-5:]

    # Calculate the dispersion (Angstrom per pixel)
    arx_disp = np.append(
        arx_skyspec.wavelength.value[1] - arx_skyspec.wavelength.value[0],
        arx_skyspec.wavelength.value[1:] - arx_skyspec.wavelength.value[:-1])
    obj_disp = np.append(
        obj_skyspec.wavelength.value[1] - obj_skyspec.wavelength.value[0],
        obj_skyspec.wavelength.value[1:] - obj_skyspec.wavelength.value[:-1])

    # Calculate resolution (lambda/delta lambda_FWHM) ... maybe don't need
    # this? can just use sigmas
    arx_idx = (arx_cent + 0.5).astype(int)[arx_w][arx_keep]  # The +0.5 is for rounding
    arx_res = arx_skyspec.wavelength.value[arx_idx]/\
              (arx_disp[arx_idx]*(2*np.sqrt(2*np.log(2)))*arx_wid[arx_w][arx_keep])
    obj_idx = (obj_cent + 0.5).astype(int)[obj_w][obj_keep]  # The +0.5 is for rounding
    obj_res = obj_skyspec.wavelength.value[obj_idx]/ \
              (obj_disp[obj_idx]*(2*np.sqrt(2*np.log(2)))*obj_wid[obj_w][obj_keep])
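    # (for a Gaussian line, FWHM = 2*sqrt(2*ln 2)*sigma, hence the factor above)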

    if not np.all(np.isfinite(obj_res)):
        msgs.warn(
            'Failed to measure the resolution of the object spectrum, likely due to error '
            'in the wavelength image.')
        return None
    msgs.info("Resolution of Archive={0} and Observation={1}".format(
        np.median(arx_res), np.median(obj_res)))

    # Determine sigma of gaussian for smoothing
    arx_sig2 = np.power(arx_disp[arx_idx] * arx_wid[arx_w][arx_keep], 2)
    obj_sig2 = np.power(obj_disp[obj_idx] * obj_wid[obj_w][obj_keep], 2)

    arx_med_sig2 = np.median(arx_sig2)
    obj_med_sig2 = np.median(obj_sig2)

    if obj_med_sig2 >= arx_med_sig2:
        smooth_sig = np.sqrt(obj_med_sig2 - arx_med_sig2)  # Ang
        smooth_sig_pix = smooth_sig / np.median(arx_disp[arx_idx])
        arx_skyspec = arx_skyspec.gauss_smooth(smooth_sig_pix * 2 *
                                               np.sqrt(2 * np.log(2)))
    else:
        msgs.warn("New Sky has higher resolution than Archive. Not smoothing")
        msgs.warn("Prefer archival sky spectrum to have higher resolution")
        smooth_sig_pix = 0.
        #smooth_sig = np.sqrt(arx_med_sig**2-obj_med_sig**2)

    #Determine region of wavelength overlap
    min_wave = max(np.amin(arx_skyspec.wavelength.value),
                   np.amin(obj_skyspec.wavelength.value))
    max_wave = min(np.amax(arx_skyspec.wavelength.value),
                   np.amax(obj_skyspec.wavelength.value))

    #Smooth higher resolution spectrum by smooth_sig (flux is conserved!)
    #    if np.median(obj_res) >= np.median(arx_res):
    #        msgs.warn("New Sky has higher resolution than Archive.  Not smoothing")
    #obj_sky_newflux = ndimage.gaussian_filter(obj_sky.flux, smooth_sig)
    #    else:
    #tmp = ndimage.gaussian_filter(arx_sky.flux, smooth_sig)
    #        arx_skyspec = arx_skyspec.gauss_smooth(smooth_sig_pix*2*np.sqrt(2*np.log(2)))
    #arx_sky.flux = ndimage.gaussian_filter(arx_sky.flux, smooth_sig)

    # Define wavelengths of overlapping spectra
    keep_idx = np.where((obj_skyspec.wavelength.value >= min_wave)
                        & (obj_skyspec.wavelength.value <= max_wave))[0]
    #keep_wave = [i for i in obj_sky.wavelength.value if i>=min_wave if i<=max_wave]

    #Rebin both spectra onto overlapped wavelength range
    if len(keep_idx) <= 50:
        msgs.warn("Not enough overlap between sky spectra")
        return None

    # rebin onto object ALWAYS
    keep_wave = obj_skyspec.wavelength[keep_idx]
    arx_skyspec = arx_skyspec.rebin(keep_wave)
    obj_skyspec = obj_skyspec.rebin(keep_wave)
    # Trim edges (rebinning is junk there)
    arx_skyspec.data['flux'][0, :2] = 0.
    arx_skyspec.data['flux'][0, -2:] = 0.
    obj_skyspec.data['flux'][0, :2] = 0.
    obj_skyspec.data['flux'][0, -2:] = 0.

    # Normalize spectra to unit average sky count
    norm = np.sum(obj_skyspec.flux.value) / obj_skyspec.npix
    obj_skyspec.flux = obj_skyspec.flux / norm
    norm2 = np.sum(arx_skyspec.flux.value) / arx_skyspec.npix
    arx_skyspec.flux = arx_skyspec.flux / norm2
    if norm < 0:
        msgs.warn("Bad normalization of object in flexure algorithm")
        msgs.warn("Will try the median")
        norm = np.median(obj_skyspec.flux.value)
        if norm < 0:
            msgs.warn("Improper sky spectrum for flexure.  Is it too faint??")
            return None
    if norm2 < 0:
        msgs.warn(
            'Bad normalization of archive in flexure. You are probably using wavelengths '
            'well beyond the archive.')
        return None

    # Deal with bad pixels
    msgs.work("Need to mask bad pixels")

    # Deal with underlying continuum
    msgs.work("Consider taking median first [5 pixel]")
    everyn = obj_skyspec.npix // 20
    bspline_par = dict(everyn=everyn)
    mask, ct = utils.robust_polyfit(obj_skyspec.wavelength.value,
                                    obj_skyspec.flux.value,
                                    3,
                                    function='bspline',
                                    sigma=3.,
                                    bspline_par=bspline_par)
    obj_sky_cont = utils.func_val(ct, obj_skyspec.wavelength.value, 'bspline')
    obj_sky_flux = obj_skyspec.flux.value - obj_sky_cont
    mask, ct_arx = utils.robust_polyfit(arx_skyspec.wavelength.value,
                                        arx_skyspec.flux.value,
                                        3,
                                        function='bspline',
                                        sigma=3.,
                                        bspline_par=bspline_par)
    arx_sky_cont = utils.func_val(ct_arx, arx_skyspec.wavelength.value,
                                  'bspline')
    arx_sky_flux = arx_skyspec.flux.value - arx_sky_cont

    # Consider sharpness filtering (e.g. LowRedux)
    msgs.work("Consider taking median first [5 pixel]")

    #Cross correlation of spectra
    #corr = np.correlate(arx_skyspec.flux, obj_skyspec.flux, "same")
    corr = np.correlate(arx_sky_flux, obj_sky_flux, "same")

    #Create array around the max of the correlation function for fitting for subpixel max
    # Restrict to pixels within maxshift of zero lag
    lag0 = corr.size // 2
    #mxshft = settings.argflag['reduce']['flexure']['maxshift']
    max_corr = np.argmax(corr[lag0 - mxshft:lag0 + mxshft]) + lag0 - mxshft
    subpix_grid = np.linspace(max_corr - 3., max_corr + 3., 7)

    #Fit a second-degree polynomial to the peak of the correlation function. JFH added this if/else to avoid crashing on bad slits
    if np.any(np.isfinite(corr[subpix_grid.astype(int)])):
        fit = utils.func_fit(subpix_grid, corr[subpix_grid.astype(int)],
                             'polynomial', 2)
        success = True
        max_fit = -0.5 * fit[1] / fit[2]
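        # (the vertex of the parabola c0 + c1*x + c2*x**2 lies at x = -c1/(2*c2))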
    else:
        fit = utils.func_fit(subpix_grid, 0.0 * subpix_grid, 'polynomial', 2)
        success = False
        max_fit = 0.0
        msgs.warn('Flexure compensation failed for one of your objects')

    #Calculate and apply shift in wavelength
    shift = float(max_fit) - lag0
    msgs.info("Flexure correction of {:g} pixels".format(shift))
    #model = (fit[2]*(subpix_grid**2.))+(fit[1]*subpix_grid)+fit[0]

    return dict(polyfit=fit,
                shift=shift,
                subpix=subpix_grid,
                corr=corr[subpix_grid.astype(int)],
                sky_spec=obj_skyspec,
                arx_spec=arx_skyspec,
                corr_cen=corr.size / 2,
                smooth=smooth_sig_pix,
                success=success)
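
# A self-contained sketch (synthetic data, not from PypeIt) of the subpixel peak
# refinement used in spec_flex_shift() above: cross-correlate, locate the maximum,
# fit a parabola around it, and take the vertex as the fractional-pixel shift.
import numpy as np

x = np.arange(100, dtype=float)
true_shift = 3.7
template = np.exp(-0.5*((x - 50.)/2.)**2)                # a single "sky line"
shifted = np.exp(-0.5*((x - 50. - true_shift)/2.)**2)    # the same line, shifted
corr = np.correlate(shifted, template, 'same')
lag0 = corr.size // 2
imax = np.argmax(corr)
grid = np.arange(imax - 3, imax + 4, dtype=float)
cfit = np.polynomial.polynomial.polyfit(grid, corr[grid.astype(int)], 2)
subpix_max = -0.5*cfit[1]/cfit[2]                        # parabola vertex
print('recovered shift = {:.2f} (true = {:.2f})'.format(subpix_max - lag0, true_shift))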
narxiv = len(wv_calib_arxiv)
nspec = wv_calib_arxiv['0']['spec'].size
# assignments
spec_arxiv = np.zeros((nspec, narxiv))

det_arxiv = {}
wave_soln_arxiv = np.zeros((nspec, narxiv))
xrng = np.arange(nspec)
for iarxiv in range(narxiv):
    spec_arxiv[:, iarxiv] = wv_calib_arxiv[str(iarxiv)]['spec']
    fitc = wv_calib_arxiv[str(iarxiv)]['fitc']
    fitfunc = wv_calib_arxiv[str(iarxiv)]['function']
    fmin, fmax = wv_calib_arxiv[str(iarxiv)]['fmin'], wv_calib_arxiv[str(iarxiv)]['fmax']
    wave_soln_arxiv[:, iarxiv] = utils.func_val(fitc, xrng, fitfunc, minv=fmin, maxv=fmax)
    det_arxiv[str(iarxiv)] = wv_calib_arxiv[str(iarxiv)]['xfit']



match_toler = 2.0 #par['match_toler']
n_first = par['n_first']
sigrej_first = par['sigrej_first']
n_final = par['n_final']
sigrej_final = par['sigrej_final']
func = par['func']
nonlinear_counts = par['nonlinear_counts']
sigdetect = par['lowest_nsig']
rms_threshold = par['rms_threshold']
lamps = par['lamps']
Example #6
plt.plot(xfit[~robust_mask],
         yfit[~robust_mask],
         'ms',
         markersize=10.0,
         mfc='None',
         label='robust_polyfit rejected')
plt.plot(xfit[~robust_mask_new],
         yfit[~robust_mask_new],
         'r+',
         markersize=20.0,
         label='robust_polyfit_djs rejected')
plt.plot(xfit[~robust_mask_nosticky],
         yfit[~robust_mask_nosticky],
         'go',
         mfc='None',
         markersize=30.0,
         label='robust_polyfit_djs (non-sticky) rejected')

plt.plot(xvec,
         utils.func_val(poly_coeff, xvec, 'polynomial'),
         lw=2,
         ls='-.',
         color='m',
         label='robust polyfit')
plt.plot(xvec,
         utils.func_val(poly_coeff_new, xvec, 'polynomial'),
         lw=1,
         ls='--',
         color='r',
         label='new robust polyfit')
plt.plot(xvec,
         utils.func_val(poly_coeff_nosticky, xvec, 'polynomial'),
         lw=1,
         ls=':',
         color='g',
         label='new robust polyfit (non-sticky)')

plt.legend()
plt.show()
Example #7
def basis(xfit,
          yfit,
          coeff,
          npc,
          pnpc,
          weights=None,
          skipx0=True,
          x0in=None,
          mask=None,
          function='polynomial'):
    nrow = xfit.shape[0]
    ntrace = xfit.shape[1]
    if x0in is None:
        x0in = np.arange(float(ntrace))

    # Mask out some orders if they are bad
    if mask is None or mask.size == 0:
        usetrace = np.arange(ntrace)
        outmask = np.ones((nrow, ntrace))
    else:
        usetrace = np.where(~np.in1d(np.arange(ntrace), mask))[0]
        outmask = np.ones((nrow, ntrace))
        outmask[:, mask] = 0.0

    # Do the PCA analysis
    eigc, hidden = get_pc(coeff[1:npc + 1, usetrace], npc)

    modl = func_vander(xfit[:, 0], function, npc)
    eigv = np.dot(modl[:, 1:], eigc)

    med_hidden = np.median(hidden, axis=1)
    med_highorder = med_hidden.copy()
    med_highorder[0] = 0

    high_order_matrix = med_highorder.T[np.newaxis, :].repeat(ntrace, axis=0)

    # y = hidden[0,:]
    # coeff0 = utils.robust_regression(x0in[usetrace], y, pnpc[1], 0.1, function=function)

    # y = hidden[1,:]
    # coeff1 = utils.robust_regression(x0in[usetrace], y, pnpc[2], 0.1, function=function)

    coeffstr = []
    for i in range(1, npc + 1):
        # if pnpc[i] == 0:
        #     coeffstr.append([-9.99E9])
        #     continue
        # coeff0 = utils.robust_regression(x0in[usetrace], hidden[i-1,:], pnpc[i], 0.1, function=function, min=x0in[0], max=x0in[-1])
        if weights is not None:
            tmask, coeff0 = utils.robust_polyfit(x0in[usetrace],
                                                 hidden[i - 1, :],
                                                 pnpc[i],
                                                 weights=weights[usetrace],
                                                 sigma=2.0,
                                                 function=function,
                                                 minx=x0in[0],
                                                 maxx=x0in[-1])
        else:
            tmask, coeff0 = utils.robust_polyfit(x0in[usetrace],
                                                 hidden[i - 1, :],
                                                 pnpc[i],
                                                 sigma=2.0,
                                                 function=function,
                                                 minx=x0in[0],
                                                 maxx=x0in[-1])
        coeffstr.append(coeff0)
        high_order_matrix[:, i - 1] = utils.func_val(coeff0,
                                                     x0in,
                                                     function,
                                                     minx=x0in[0],
                                                     maxx=x0in[-1])
    # high_order_matrix[:,1] = utils.func_val(coeff1, x0in, function)
    high_fit = high_order_matrix.copy()

    high_order_fit = np.dot(eigv, high_order_matrix.T)
    sub = (yfit - high_order_fit) * outmask

    numer = np.sum(sub, axis=0)
    denom = np.sum(outmask, axis=0)
    x0 = np.zeros(ntrace, dtype=float)
    fitmask = np.zeros(ntrace, dtype=float)
    #fitmask[mask] = 1
    x0fit = np.zeros(ntrace, dtype=float)
    chisqnu = 0.0
    chisqold = 0.0
    robust = True
    #svx0 = numer/(denom+(denom == 0).astype(np.int))
    if not skipx0:
        fitmask = (np.abs(denom) > 10).astype(int)
        if robust:
            good = np.where(fitmask != 0)[0]
            bad = np.where(fitmask == 0)[0]
            x0[good] = numer[good] / denom[good]
            imask = np.zeros(ntrace, dtype=float)
            imask[bad] = 1.0
            ttmask, x0res = utils.robust_polyfit(x0in,
                                                 x0,
                                                 pnpc[0],
                                                 weights=weights,
                                                 sigma=2.0,
                                                 function=function,
                                                 minx=x0in[0],
                                                 maxx=x0in[-1],
                                                 initialmask=imask)
            x0fit = utils.func_val(x0res,
                                   x0in,
                                   function,
                                   minx=x0in[0],
                                   maxx=x0in[-1])
            good = np.where(ttmask == 0)[0]
            xstd = 1.0  # This should represent the dispersion in the fit
            chisq = ((x0[good] - x0fit[good]) / xstd)**2.0
            chisqnu = np.sum(chisq) / np.sum(ttmask)
            fitmask = 1.0 - ttmask
            msgs.prindent("  Reduced chi-squared = {0:E}".format(chisqnu))
        else:
            for i in range(1, 5):
                good = np.where(fitmask != 0)[0]
                x0[good] = numer[good] / denom[good]
                #				x0res = utils.robust_regression(x0in[good],x0[good],pnpc[0],0.2,function=function)
                x0res = utils.func_fit(x0in[good],
                                       x0[good],
                                       function,
                                       pnpc[0],
                                       weights=weights,
                                       minx=x0in[0],
                                       maxx=x0in[-1])
                x0fit = utils.func_val(x0res,
                                       x0in,
                                       function,
                                       minx=x0in[0],
                                       maxx=x0in[-1])
                chisq = (x0[good] - x0fit[good])**2.0
                fitmask[good] *= (chisq < np.sum(chisq) / 2.0).astype(int)
                chisqnu = np.sum(chisq) / np.sum(fitmask)
                msgs.prindent("  Reduced chi-squared = {0:E}".format(chisqnu))
                if chisqnu == chisqold:
                    break
                else:
                    chisqold = chisqnu
        if chisqnu > 2.0:
            msgs.warn("PCA has very large residuals")
        elif chisqnu > 0.5:
            msgs.warn("PCA has fairly large residuals")
        #bad = np.where(fitmask==0)[0]
        #x0[bad] = x0fit[bad]
    else:
        x0res = 0.0
    x3fit = np.dot(eigv, high_order_matrix.T) + np.outer(x0fit,
                                                         np.ones(nrow)).T
    outpar = dict({
        'high_fit': high_fit,
        'x0': x0,
        'x0in': x0in,
        'x0fit': x0fit,
        'x0res': x0res,
        'x0mask': fitmask,
        'hidden': hidden,
        'usetrc': usetrace,
        'eigv': eigv,
        'npc': npc,
        'coeffstr': coeffstr
    })
    return x3fit, outpar
Example #8
def iterative_fitting(spec,
                      tcent,
                      ifit,
                      IDs,
                      llist,
                      disp,
                      match_toler=2.0,
                      func='legendre',
                      n_first=2,
                      sigrej_first=2.0,
                      n_final=4,
                      sigrej_final=3.0,
                      weights=None,
                      plot_fil=None,
                      verbose=False):
    """ Routine for iteratively fitting wavelength solutions.

    Parameters
    ----------
    spec : ndarray, shape = (nspec,)
      arcline spectrum
    tcent : ndarray
      Centroids in pixels of lines identified in spec
    ifit : ndarray
      Indices of the lines that will be fit
    IDs: ndarray
      wavelength IDs of the lines that will be fit (I think?)
    llist: dict
      Linelist dictionary
    disp: float
      dispersion

    Optional Parameters
    -------------------
    match_toler: float, default = 2.0
      Matching tolerance when searching for new lines. This is the difference in pixels between the wavelength assigned to
      an arc line by an iteration of the wavelength solution and the wavelength in the line list.
    func: str, default = 'legendre'
      Name of function used for the wavelength solution
    n_first: int, default = 2
      Order of first guess to the wavelength solution.
    sigrej_first: float, default = 2.0
      Number of sigma for rejection for the first guess to the wavelength solution.
    n_final: int, default = 4
      Order of the final wavelength solution fit
    sigrej_final: float, default = 3.0
      Number of sigma for rejection for the final fit to the wavelength solution.
    weights: ndarray
      Weights to be used?
    verbose : bool
      If True, print out more information.
    plot_fil:
      Filename for plotting some QA?

    Returns
    -------
    final_fit: dict
      Dictionary containing the full fitting results and the final best guess of the line IDs
    """

    #TODO JFH add error checking here to ensure that IDs and ifit have the same size!

    if weights is None:
        weights = np.ones(tcent.size)

    nspec = spec.size
    xnspecmin1 = float(nspec - 1)
    # Setup for fitting
    sv_ifit = list(ifit)  # Keep the originals
    all_ids = -999. * np.ones(len(tcent))
    all_idsion = np.array(['UNKNWN'] * len(tcent))
    all_ids[ifit] = IDs

    # Fit
    n_order = n_first
    flg_quit = False
    fmin, fmax = 0.0, 1.0
    while (n_order <= n_final) and (flg_quit is False):
        # Fit with rejection
        xfit, yfit, wfit = tcent[ifit], all_ids[ifit], weights[ifit]
        mask, fit = utils.robust_polyfit(xfit / xnspecmin1,
                                         yfit,
                                         n_order,
                                         function=func,
                                         sigma=sigrej_first,
                                         minx=fmin,
                                         maxx=fmax,
                                         verbose=verbose,
                                         weights=wfit)

        rms_ang = utils.calc_fit_rms(xfit[mask == 0] / xnspecmin1,
                                     yfit[mask == 0],
                                     fit,
                                     func,
                                     minx=fmin,
                                     maxx=fmax,
                                     weights=wfit[mask == 0])
        rms_pix = rms_ang / disp
        if verbose:
            msgs.info('n_order = {:d}'.format(n_order) +
                      ': RMS = {:g}'.format(rms_pix))

        # Reject but keep originals (until final fit)
        ifit = list(ifit[mask == 0]) + sv_ifit
        # Find new points (should we allow removal of the originals?)
        twave = utils.func_val(fit,
                               tcent / xnspecmin1,
                               func,
                               minx=fmin,
                               maxx=fmax)
        for ss, iwave in enumerate(twave):
            mn = np.min(np.abs(iwave - llist['wave']))
            if mn / disp < match_toler:
                imn = np.argmin(np.abs(iwave - llist['wave']))
                #if verbose:
                #    print('Adding {:g} at {:g}'.format(llist['wave'][imn],tcent[ss]))
                # Update and append
                all_ids[ss] = llist['wave'][imn]
                all_idsion[ss] = llist['ion'][imn]
                ifit.append(ss)
        # Keep unique ones
        ifit = np.unique(np.array(ifit, dtype=int))
        # Increment order
        if n_order < (n_final + 2):
            n_order += 1
        else:
            # This does 2 iterations at the final order
            flg_quit = True

    # Final fit (originals can now be rejected)
    #fmin, fmax = 0., 1.
    #xfit, yfit, wfit = tcent[ifit]/(nspec-1), all_ids[ifit], weights[ifit]
    xfit, yfit, wfit = tcent[ifit], all_ids[ifit], weights[ifit]
    mask, fit = utils.robust_polyfit(xfit / xnspecmin1,
                                     yfit,
                                     n_order,
                                     function=func,
                                     sigma=sigrej_final,
                                     minx=fmin,
                                     maxx=fmax,
                                     verbose=verbose,
                                     weights=wfit)  #, debug=True)
    irej = np.where(mask == 1)[0]
    if len(irej) > 0:
        xrej = xfit[irej]
        yrej = yfit[irej]
        if verbose:
            for kk, imask in enumerate(irej):
                wave = utils.func_val(fit,
                                      xrej[kk] / xnspecmin1,
                                      func,
                                      minx=fmin,
                                      maxx=fmax)
                msgs.info('Rejecting arc line {:g}; {:g}'.format(
                    yfit[imask], wave))
    else:
        xrej = []
        yrej = []

    #xfit = xfit[mask == 0]
    #yfit = yfit[mask == 0]
    #wfit = wfit[mask == 0]
    ions = all_idsion[ifit]
    #    ions = all_idsion[ifit][mask == 0]
    # Final RMS
    rms_ang = utils.calc_fit_rms(xfit[mask == 0] / xnspecmin1,
                                 yfit[mask == 0],
                                 fit,
                                 func,
                                 minx=fmin,
                                 maxx=fmax,
                                 weights=wfit[mask == 0])
    #    rms_ang = utils.calc_fit_rms(xfit, yfit, fit, func,
    #                                 minx=fmin, maxx=fmax, weights=wfit)
    rms_pix = rms_ang / disp

    # Pack up fit
    spec_vec = np.arange(nspec)
    wave_soln = utils.func_val(fit,
                               spec_vec / xnspecmin1,
                               func,
                               minx=fmin,
                               maxx=fmax)
    cen_wave = utils.func_val(fit,
                              float(nspec) / 2 / xnspecmin1,
                              func,
                              minx=fmin,
                              maxx=fmax)
    cen_wave_min1 = utils.func_val(fit, (float(nspec) / 2 - 1.0) / xnspecmin1,
                                   func,
                                   minx=fmin,
                                   maxx=fmax)
    cen_disp = cen_wave - cen_wave_min1

    final_fit = dict(fitc=fit,
                     function=func,
                     pixel_fit=xfit,
                     wave_fit=yfit,
                     weights=wfit,
                     ions=ions,
                     fmin=fmin,
                     fmax=fmax,
                     xnorm=xnspecmin1,
                     nspec=nspec,
                     cen_wave=cen_wave,
                     cen_disp=cen_disp,
                     xrej=xrej,
                     yrej=yrej,
                     mask=(mask == 0),
                     spec=spec,
                     wave_soln=wave_soln,
                     nrej=sigrej_final,
                     shift=0.,
                     tcent=tcent,
                     rms=rms_pix)

    # If set to True, this will output a file that can then be included in the tests
    saveit = False
    if saveit:
        from linetools import utils as ltu
        jdict = ltu.jsonify(final_fit)
        if plot_fil is None:
            outname = "temp"
            print(
                "You should have set the plot_fil directory to save wavelength fits... using 'temp' as a filename"
            )
        else:
            outname = plot_fil
        ltu.savejson(outname + '.json',
                     jdict,
                     easy_to_read=True,
                     overwrite=True)
        print(" Wrote: {:s}".format(outname + '.json'))

    # QA
    if plot_fil is not None:
        autoid.arc_fit_qa(final_fit, plot_fil)
    # Return
    return final_fit
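
# A self-contained sketch (synthetic data, not from PypeIt) of the fit/reject loop at
# the heart of iterative_fitting(): fit, reject outliers against the current solution,
# raise the polynomial order, and refit, re-testing every point on each pass.
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0., 1., 40)
y = 3000. + 2000.*x + 150.*x**2 + rng.normal(0., 0.5, x.size)
y[[5, 17]] += 25.                        # two bogus line identifications
good = np.ones(x.size, dtype=bool)
for order in (1, 2, 3):
    cfit = np.polynomial.polynomial.polyfit(x[good], y[good], order)
    resid = y - np.polynomial.polynomial.polyval(x, cfit)
    sigma = np.std(resid[good])
    good = np.abs(resid) < 3.0*sigma     # reject, but re-admit points that now fit
print('rejected indices:', np.where(~good)[0])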
Example #9
def flexure_qa_oldbuggyversion(specobjs,
                               maskslits,
                               basename,
                               det,
                               flex_list,
                               slit_cen=False):
    """ QA on flexure measurement

    Parameters
    ----------
    det
    flex_list : list
      list of dict containing flexure results
    slit_cen : bool, optional
      QA on slit center instead of objects

    Returns
    -------

    """
    plt.rcdefaults()
    plt.rcParams['font.family'] = 'times new roman'

    # Grab the name of the method
    method = inspect.stack()[0][3]
    #
    gdslits = np.where(~maskslits)[0]
    for sl in range(len(specobjs)):
        if sl not in gdslits:
            continue
        if specobjs[sl][0] is None:
            continue
        # Setup
        if slit_cen:
            nobj = 1
            ncol = 1
        else:
            nobj = len(specobjs[sl])
            ncol = min(3, nobj)
        #
        if nobj == 0:
            continue
        nrow = nobj // ncol + ((nobj % ncol) > 0)

        # Get the flexure dictionary
        flex_dict = flex_list[sl]

        # Outfile
        outfile = qa.set_qa_filename(basename,
                                     method + '_corr',
                                     det=det,
                                     slit=specobjs[sl][0].SLITID)

        plt.figure(figsize=(8, 5.0))
        plt.clf()
        gs = gridspec.GridSpec(nrow, ncol)

        # Correlation QA
        for o in range(nobj):
            ax = plt.subplot(gs[o // ncol, o % ncol])
            # Fit
            fit = flex_dict['polyfit'][o]
            xval = np.linspace(-10., 10, 100) + flex_dict['corr_cen'][o]  #+ flex_dict['shift'][o]
            #model = (fit[2]*(xval**2.))+(fit[1]*xval)+fit[0]
            model = utils.func_val(fit, xval, 'polynomial')
            mxmod = np.max(model)
            ylim = [np.min(model / mxmod), 1.3]
            ax.plot(xval - flex_dict['corr_cen'][o], model / mxmod, 'k-')
            # Measurements
            ax.scatter(flex_dict['subpix'][o] - flex_dict['corr_cen'][o],
                       flex_dict['corr'][o] / mxmod,
                       marker='o')
            # Final shift
            ax.plot([flex_dict['shift'][o]] * 2, ylim, 'g:')
            # Label
            if slit_cen:
                ax.text(0.5,
                        0.25,
                        'Slit Center',
                        transform=ax.transAxes,
                        size='large',
                        ha='center')
            else:
                ax.text(0.5,
                        0.25,
                        '{:s}'.format(specobjs[sl][o].NAME),
                        transform=ax.transAxes,
                        size='large',
                        ha='center')
            ax.text(0.5,
                    0.15,
                    'flex_shift = {:g}'.format(flex_dict['shift'][o]),
                    transform=ax.transAxes,
                    size='large',
                    ha='center')  #, bbox={'facecolor':'white'})
            # Axes
            ax.set_ylim(ylim)
            ax.set_xlabel('Lag')

        # Finish
        plt.tight_layout(pad=0.2, h_pad=0.0, w_pad=0.0)
        plt.savefig(outfile, dpi=400)
        plt.close()

        # Sky line QA (just one object)
        if slit_cen:
            o = 0
        else:
            o = 0
            specobj = specobjs[sl][o]
        sky_spec = flex_dict['sky_spec'][o]
        arx_spec = flex_dict['arx_spec'][o]

        # Sky lines
        sky_lines = np.array([
            3370.0, 3914.0, 4046.56, 4358.34, 5577.338, 6300.304, 7340.885,
            7993.332, 8430.174, 8919.610, 9439.660, 10013.99, 10372.88
        ]) * units.AA
        dwv = 20. * units.AA
        gdsky = np.where((sky_lines > sky_spec.wvmin)
                         & (sky_lines < sky_spec.wvmax))[0]
        if len(gdsky) == 0:
            msgs.warn("No sky lines for Flexure QA")
            return
        if len(gdsky) > 6:
            idx = np.array(
                [0, 1, len(gdsky) // 2,
                 len(gdsky) // 2 + 1, -2, -1])
            gdsky = gdsky[idx]

        # Outfile
        outfile = qa.set_qa_filename(basename,
                                     method + '_sky',
                                     det=det,
                                     slit=specobjs[sl][0].SLITID)
        # Figure
        plt.figure(figsize=(8, 5.0))
        plt.clf()
        nrow, ncol = 2, 3
        gs = gridspec.GridSpec(nrow, ncol)
        if slit_cen:
            plt.suptitle('Sky Comparison for Slit Center', y=1.05)
        else:
            plt.suptitle('Sky Comparison for {:s}'.format(specobj.NAME),
                         y=1.05)

        for ii, igdsky in enumerate(gdsky):
            skyline = sky_lines[igdsky]
            ax = plt.subplot(gs[ii // ncol, ii % ncol])
            # Scale the archive sky to the object sky using the summed flux
            # in the line window; both sums use the object pixel indices,
            # which assumes the two spectra share a wavelength grid.
            pix = np.where(np.abs(sky_spec.wavelength - skyline) < dwv)[0]
            f1 = np.sum(sky_spec.flux[pix])
            f2 = np.sum(arx_spec.flux[pix])
            norm = f1 / f2
            # Plot
            ax.plot(sky_spec.wavelength[pix],
                    sky_spec.flux[pix],
                    'k-',
                    label='Obj',
                    drawstyle='steps-mid')
            pix2 = np.where(np.abs(arx_spec.wavelength - skyline) < dwv)[0]
            ax.plot(arx_spec.wavelength[pix2],
                    arx_spec.flux[pix2] * norm,
                    'r-',
                    label='Arx',
                    drawstyle='steps-mid')
            # Axes
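            # Major x-axis ticks every dwv (20 Angstrom), the window half-width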
            ax.xaxis.set_major_locator(plt.MultipleLocator(dwv.value))
            ax.set_xlabel('Wavelength')
            ax.set_ylabel('Counts')

        # Legend
        plt.legend(loc='upper left',
                   scatterpoints=1,
                   borderpad=0.3,
                   handletextpad=0.3,
                   fontsize='small',
                   numpoints=1)

        # Finish
        plt.savefig(outfile, dpi=400)
        plt.close()

    plt.rcdefaults()

### Example #10
def flexure_qa(specobjs,
               maskslits,
               basename,
               det,
               flex_list,
               slit_cen=False,
               out_dir=None):
    """

    Args:
        specobjs:
        maskslits (np.ndarray):
        basename (str):
        det (int):
        flex_list (list):
        slit_cen:
        out_dir:

    """
    plt.rcdefaults()
    plt.rcParams['font.family'] = 'times new roman'

    # Grab the name of this method for use in the QA filenames
    method = inspect.stack()[0][3]
    #
    gdslits = np.where(np.invert(maskslits))[0]

    # Loop over slits, and then over objects here
    for slit in gdslits:
        indx = specobjs.slitorder_indices(slit)
        this_specobjs = specobjs[indx]
        this_flex_dict = flex_list[slit]

        # Setup
        if slit_cen:
            nobj = 1
            ncol = 1
        else:
            nobj = np.sum(indx)
            ncol = min(3, nobj)
        #
        if nobj == 0:
            continue
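        # Number of panel rows: ceil(nobj / ncol)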
        nrow = nobj // ncol + ((nobj % ncol) > 0)
        # Outfile, one QA file per slit
        outfile = qa.set_qa_filename(basename,
                                     method + '_corr',
                                     det=det,
                                     slit=(slit + 1),
                                     out_dir=out_dir)
        plt.figure(figsize=(8, 5.0))
        plt.clf()
        gs = gridspec.GridSpec(nrow, ncol)
        for iobj, specobj in enumerate(this_specobjs):
            if specobj is None or len(specobj._data.keys()) == 1:
                continue
            # Correlation QA
            ax = plt.subplot(gs[iobj // ncol, iobj % ncol])
            # Fit
            fit = this_flex_dict['polyfit'][iobj]
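            # Evaluate the polynomial fit to the cross-correlation peak
            # over a +/-10 pixel lag window around the correlation center.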
            xval = np.linspace(-10., 10., 100) + this_flex_dict['corr_cen'][iobj]
            model = utils.func_val(fit, xval, 'polynomial')
            mxmod = np.max(model)
            ymin = np.min(model / mxmod)
            # Guard against a non-finite lower limit (e.g. when mxmod is zero)
            ylim = [ymin if np.isfinite(ymin) else 0.0, 1.3]
            ax.plot(xval - this_flex_dict['corr_cen'][iobj], model / mxmod,
                    'k-')
            # Measurements
            ax.scatter(this_flex_dict['subpix'][iobj] -
                       this_flex_dict['corr_cen'][iobj],
                       this_flex_dict['corr'][iobj] / mxmod,
                       marker='o')
            # Final shift
            ax.plot([this_flex_dict['shift'][iobj]] * 2, ylim, 'g:')
            # Label
            if slit_cen:
                ax.text(0.5,
                        0.25,
                        'Slit Center',
                        transform=ax.transAxes,
                        size='large',
                        ha='center')
            else:
                ax.text(0.5,
                        0.25,
                        '{:s}'.format(specobj.NAME),
                        transform=ax.transAxes,
                        size='large',
                        ha='center')
            ax.text(0.5,
                    0.15,
                    'flex_shift = {:g}'.format(this_flex_dict['shift'][iobj]),
                    transform=ax.transAxes,
                    size='large',
                    ha='center')
            # Axes
            ax.set_ylim(ylim)
            ax.set_xlabel('Lag')
        # Finish
        plt.tight_layout(pad=0.2, h_pad=0.0, w_pad=0.0)
        plt.savefig(outfile, dpi=400)
        plt.close()

        # Sky line QA (just one object)
        iobj = 0
        if not slit_cen:
            specobj = this_specobjs[iobj]

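        # Bail out if no flexure shift was measured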
        if len(this_flex_dict['shift']) == 0:
            return

        # Repackage
        sky_spec = this_flex_dict['sky_spec'][iobj]
        arx_spec = this_flex_dict['arx_spec'][iobj]

        # Wavelengths (Angstroms) of bright sky lines used for the comparison panels
        sky_lines = np.array([
            3370.0, 3914.0, 4046.56, 4358.34, 5577.338, 6300.304, 7340.885,
            7993.332, 8430.174, 8919.610, 9439.660, 10013.99, 10372.88
        ]) * units.AA
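        # Half-width of the wavelength window extracted around each line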
        dwv = 20. * units.AA
        gdsky = np.where((sky_lines > sky_spec.wvmin)
                         & (sky_lines < sky_spec.wvmax))[0]
        if len(gdsky) == 0:
            msgs.warn("No sky lines for Flexure QA")
            return
        if len(gdsky) > 6:
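            # Keep at most six lines (first two, middle two, last two)
            # so they fill the 2x3 panel grid below.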
            idx = np.array(
                [0, 1, len(gdsky) // 2,
                 len(gdsky) // 2 + 1, -2, -1])
            gdsky = gdsky[idx]

        # Outfile
        outfile = qa.set_qa_filename(basename,
                                     method + '_sky',
                                     det=det,
                                     slit=(slit + 1),
                                     out_dir=out_dir)
        # Figure
        plt.figure(figsize=(8, 5.0))
        plt.clf()
        nrow, ncol = 2, 3
        gs = gridspec.GridSpec(nrow, ncol)
        if slit_cen:
            plt.suptitle('Sky Comparison for Slit Center', y=1.05)
        else:
            plt.suptitle('Sky Comparison for {:s}'.format(specobj.NAME),
                         y=1.05)

        for ii, igdsky in enumerate(gdsky):
            skyline = sky_lines[igdsky]
            ax = plt.subplot(gs[ii // ncol, ii % ncol])
            # Scale the archive sky to the object sky using the summed flux
            # in the line window; both sums use the object pixel indices,
            # which assumes the two spectra share a wavelength grid.
            pix = np.where(np.abs(sky_spec.wavelength - skyline) < dwv)[0]
            f1 = np.sum(sky_spec.flux[pix])
            f2 = np.sum(arx_spec.flux[pix])
            norm = f1 / f2
            # Plot
            ax.plot(sky_spec.wavelength[pix],
                    sky_spec.flux[pix],
                    'k-',
                    label='Obj',
                    drawstyle='steps-mid')
            pix2 = np.where(np.abs(arx_spec.wavelength - skyline) < dwv)[0]
            ax.plot(arx_spec.wavelength[pix2],
                    arx_spec.flux[pix2] * norm,
                    'r-',
                    label='Arx',
                    drawstyle='steps-mid')
            # Axes
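            # Major x-axis ticks every dwv (20 Angstrom), the window half-width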
            ax.xaxis.set_major_locator(plt.MultipleLocator(dwv.value))
            ax.set_xlabel('Wavelength')
            ax.set_ylabel('Counts')

        # Legend
        plt.legend(loc='upper left',
                   scatterpoints=1,
                   borderpad=0.3,
                   handletextpad=0.3,
                   fontsize='small',
                   numpoints=1)

        # Finish
        plt.savefig(outfile, dpi=400)
        plt.close()

    plt.rcdefaults()

    return
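

# A minimal usage sketch (hypothetical inputs: 'sobjs', 'bad_slits', and
# 'flex_list' stand in for objects produced earlier in the reduction):
#
#     flexure_qa(sobjs, bad_slits, basename='keck_lris_blue_A', det=1,
#                flex_list=flex_list, slit_cen=False, out_dir='QA/')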