def linear_mn_hght_bg(xvals,yvals,invals,sigma,mn_est,power=2,max_iter=1,tol=0.01):
    """ Find mean of gaussian with a linearized procedure, then best fit hght and bg.

        Each iteration shifts the center estimate by
            dx = sign * dh * sigma/2 * exp(1/2)
        where dh = (max residual - min residual)/max(yvals), and sign is -1
        when the largest residual sits at a lower index than the smallest
        residual (+1 otherwise). After every shift the height and background
        are refit with sf.best_linear_gauss and the model is rebuilt.
        INPUTS:
            xvals, yvals - sample positions and data values
            invals - inverse variance of yvals (passed to sf.best_linear_gauss)
            sigma - gaussian width (held fixed)
            mn_est - starting center estimate (the caller's object is not modified)
            power - profile exponent passed to sf.gaussian
            max_iter - maximum number of center updates (default 1, which is
                       the historical behavior of this function)
            tol - stop early once the center moved by |dx| <= tol
        OUTPUTS:
            mn_est, hght, bg - refined center, height and background
        RAISES:
            ValueError if max_iter < 1 (hght/bg would never be computed)
        NOTE: dh is normalized by max(yvals); callers must ensure it is nonzero.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")
    ymax = np.max(yvals)
    gauss_model = sf.gaussian(xvals,sigma,center=mn_est,height=ymax,power=power)
    hght = bg = None
    for _ in range(max_iter):
        residuals = yvals-gauss_model
        dh = (np.max(residuals)-np.min(residuals))/ymax
        # Move toward the side where the data exceeds the model.
        sign = -1 if np.argmax(residuals) < np.argmin(residuals) else 1
        dx = sign*sigma*dh*np.exp(0.5)/2
        # Rebind instead of += so an array passed by the caller is not mutated.
        mn_est = mn_est + dx
        hght, bg = sf.best_linear_gauss(xvals,sigma,mn_est,yvals,invals,power=power)
        gauss_model = sf.gaussian(xvals,sigma,center=mn_est,height=hght,bg_mean=bg,power=power)
        if abs(dx) <= tol:
            break
    return mn_est, hght, bg
def remove_ccd_background(ccd,cut=None,plot=False):
    """ Use to remove diffuse background (not bias).
        Assumes a gaussian distribution of background pixel values.

        A gaussian is fit (scipy curve_fit with sf.gaussian) to the histogram
        of pixel values strictly inside (-cut, cut), using only bins whose left
        edge lies below the median (the lower tail is assumed to be a cleaner
        indicator than the upper tail). The fitted center is subtracted from
        ccd IN PLACE.
        INPUTS:
            ccd - image array; modified in place, so it must be able to hold
                  a float subtraction result
            cut - histogram limit; defaults to 3*median(ccd); truncated to int
                  and must be >= 2 (the histogram uses 2*(cut-1) bins)
            plot - if True, show the histogram with the fitted gaussian
        OUTPUTS:
            ccd - the same array with the background level removed
            bg_std - fitted gaussian sigma (1 sigma background error)
        RAISES:
            ValueError if cut < 2 or no pixels/bins are available to fit
    """
    if cut is None:
        cut = 3*np.median(ccd)
    cut = int(cut)
    if cut < 2:
        raise ValueError("cut must be >= 2, got {}".format(cut))
    masked_ccd = ccd[(ccd < cut) & (ccd > -cut)]
    if masked_ccd.size == 0:
        raise ValueError("no pixels with |value| < cut={}".format(cut))
    nbins = 2*(cut-1)
    # np.histogram gives the same counts/edges plt.hist computed, without
    # drawing into (and later closing) the caller's current figure.
    hgt, edges = np.histogram(masked_ccd, nbins)
    hgt = hgt.astype(float)
    xvl = edges[:-1]  # left bin edges
    med = np.median(masked_ccd)
    ### Assume lower tail is a better indicator than upper tail
    xmsk = (xvl < med)
    hgts = hgt[xmsk]
    xvls = xvl[xmsk]
    if hgts.size == 0:
        raise ValueError("no histogram bins below the median to fit")
    # Peak-to-half-maximum distance on the lower side, scaled to a sigma guess.
    sig_est = 2/2.35*(xvls[np.argmax(hgts)] - xvls[np.argmax(hgts>np.max(hgts)/2)])
    pguess = (sig_est,med,np.max(hgt))
    sigma = 1/np.sqrt(np.abs(hgts)+1)  # ~Poisson weights, +1 avoids div by zero
    params, errarr = opt.curve_fit(sf.gaussian,xvls,hgts,p0=pguess,sigma=sigma)
    if plot:
        plt.hist(masked_ccd,nbins)
        plt.title("Number of pixels with certain count value")
        htst = sf.gaussian(xvl, params[0], center=params[1], height=params[2],bg_mean=0,bg_slope=0,power=2)
        plt.plot(xvl,htst)
        plt.show()
        plt.close()
    ccd -= params[1] # mean
    bg_std = params[0]
    return ccd, bg_std
def _arc_peak_candidates(spectrum,sampling_est):
    """ Return pixel indices of isolated local maxima in a 1D arc spectrum.

        A candidate is strictly greater than both neighbors and lies at least
        2*sampling_est pixels from either end. Candidates closer than
        2*sampling_est to another candidate are all dropped (blends can't be
        fit well), as are candidates with a zero-valued immediate neighbor.
    """
    npix = len(spectrum)
    lo = max(2*sampling_est, 1)
    hi = min(npix-2*sampling_est, npix-1)
    if hi <= lo:
        return np.zeros(0, dtype=int)
    idx = np.arange(lo, hi)
    is_max = (spectrum[idx] > spectrum[idx-1]) & (spectrum[idx] > spectrum[idx+1])
    cand = idx[is_max]
    #Cut out any that are within 2*sampling of each other (both members of a pair)
    close = np.nonzero(np.ediff1d(cand) < 2*sampling_est)[0]
    if close.size:
        cand = np.delete(cand, np.unique(np.concatenate((close, close+1))))
    #Also cut out any with a zero 1 pixel to either side
    zero_nbr = (spectrum[cand-1] == 0) | (spectrum[cand+1] == 0)
    return cand[~zero_nbr]


def arc_peaks(data,wvln,invar,ts,sampling_est=3,pad=4):
    """ Finds the position, wavelength, amplitude, width, etc. of distinct
        peaks along an extracted arc frame.
        INPUTS:
            data - extracted arc frame (a x b x c) a=telescope, b=fiber,
                   c=pixel
            wvln - wavelength corresponding to each point in 'data'
            invar - inverse variance of each point in 'data'
            ts - telescope number to use ('a' in data)
            sampling_est - estimated FWHM in pixels (must be a positive integer)
            pad - number of pixels to fit to either side of the peak pixel
                  (window is clipped to the pixel range)
        OUTPUTS (dicts keyed by fiber index, each value an array with one
        entry per successfully fit peak):
            pos_d - pixel index of each peak
            wl_d - fitted center in wavelength space
            mx_it_d - fitted height
            stddev_d - fitted sigma parameter (wavelength units, not FWHM)
            chi_d - chi^2 per pixel of the fit
            err_d - sqrt of errarr[0,0] from sf.gauss_fit (presumably the
                    sigma parameter error - confirm against sf.gauss_fit)
        Peaks whose fit raises RuntimeError or whose chi^2 per pixel
        exceeds 10 are dropped.
    """
    #use dictionaries since number of peaks per fiber varies
    mx_it_d = dict()
    stddev_d = dict()
    wl_d = dict()
    pos_d = dict()
    chi_d = dict()
    err_d = dict()
    offsets = np.arange(-pad, pad+1)  # symmetric window around the peak
    for i in range(np.shape(data)[1]):
        spectrum = data[ts,i,:]
        npix = len(spectrum)
        #Spectrum has ~no background, so a simple local-maximum finder works
        #(scipy's find_peaks_cwt misses many low amplitude peaks).
        pos_est = _arc_peak_candidates(spectrum, sampling_est)
        num_pks = len(pos_est)
        mx_it = np.zeros(num_pks)
        stddev = np.zeros(num_pks)
        pos = np.zeros(num_pks)
        chi = np.zeros(num_pks)
        err = np.zeros(num_pks)
        good = np.zeros(num_pks, dtype=bool)
        #Now fit gaussian (no background) to each candidate
        for j, peak in enumerate(pos_est):
            xarr = peak + offsets
            xarr = xarr[(xarr >= 0) & (xarr < npix)]
            yarr = spectrum[xarr]
            wlarr = wvln[ts,i,xarr]
            invarr = invar[ts,i,xarr]
            try:
                params, errarr = sf.gauss_fit(wlarr,yarr,invr=invarr,fit_background='n')
            except RuntimeError:
                continue  # failed fit: peak is dropped
            tot = sf.gaussian(wlarr,abs(params[0]),params[1],params[2])
            chi[j] = sum((yarr-tot)**2*invarr)/len(yarr)
            if chi[j] > 10: #arbitrary cutoff
                continue
            mx_it[j] = params[2] #height
            stddev[j] = params[0] #sigma parameter
            err[j] = np.sqrt(errarr[0,0])
            pos[j] = params[1] #center (wavelength)
            good[j] = True
        mx_it_d[i] = mx_it[good]
        stddev_d[i] = stddev[good]
        wl_d[i] = pos[good]
        pos_d[i] = pos_est[good]
        chi_d[i] = chi[good]
        err_d[i] = err[good]
    return pos_d, wl_d, mx_it_d, stddev_d, chi_d, err_d
def _fit_trace_quadratic_window(x,y,ccd,xpad=2):
    """ Quadratic chi^2 fit (via sf.chi_fit) to ccd[x-xpad:x+xpad+1, y]; returns (coeffs, chi).
        NOTE(review): x-xpad < 0 wraps around via negative indexing - confirm callers avoid edges.
    """
    xvals = np.arange(-xpad,xpad+1)
    zvals = ccd[x+xvals,y]
    profile = np.ones((2*xpad+1,3))
    profile[:,1] = xvals
    profile[:,2] = xvals**2
    noise = np.diag((1/zvals))
    return sf.chi_fit(zvals,profile,noise)


def _fit_trace_vertex(coeffs):
    """ Return (offset, value) at the vertex of coeffs[0] + coeffs[1]*t + coeffs[2]*t**2."""
    ca = coeffs[2]
    cb = coeffs[1]
    xc = -cb/(2*ca)
    return xc, ca*xc**2+cb*xc+coeffs[0]


def fit_trace(x,y,ccd,form='gaussian'):
    """ Fit the profile of a trace along the first ccd axis around (x, y).
        x, y are truncated to int; x indexes the first axis, y the second.
        form:
            'quadratic' - quadratic fit to the 5 pixels x-2..x+2. If its chi^2
                          exceeds 100, windows centered on x-1 and x+1 are also
                          fit and one of them is used when its chi^2 is strictly
                          lowest. The center is the parabola vertex.
            'gaussian'  - sf.gauss_fit to pixels x-7..x+7 (clipped to the ccd).
        OUTPUTS:
            center (absolute first-axis position), peak value, chi^2
        RAISES:
            ValueError for any other form
    """
    x = int(x)
    y = int(y)
    if form=='quadratic':
        coeffs, chi = _fit_trace_quadratic_window(x,y,ccd)
        best_x, best_coeffs, best_chi = x, coeffs, chi
        chi_max = 100
        if chi>chi_max:
            #bad fit: try adjacent x
            cl, chil = _fit_trace_quadratic_window(x-1,y,ccd)
            cr, chir = _fit_trace_quadratic_window(x+1,y,ccd)
            if chil<chi and chil<chir:
                best_x, best_coeffs, best_chi = x-1, cl, chil
            elif chir<chi and chir<chil:
                best_x, best_coeffs, best_chi = x+1, cr, chir
        xv, zv = _fit_trace_vertex(best_coeffs)
        return best_x+xv, zv, best_chi
    elif form=='gaussian':
        xpad = 7
        xvals = np.arange(-xpad,xpad+1)
        xinds = x+xvals
        xvals = xvals[(xinds>=0)&(xinds<np.shape(ccd)[0])]
        zvals = ccd[x+xvals,y]
        params, errarr = sf.gauss_fit(xvals,zvals)
        xc = x+params[1] #offset plus center
        zc = params[2] #height (intensity)
        fit = sf.gaussian(xvals,abs(params[0]),params[1],params[2],params[3],params[4])
        chi = sum((fit-zvals)**2/zvals)
        return xc, zc, chi
    raise ValueError("form must be 'quadratic' or 'gaussian', got {!r}".format(form))
def refine_trace_centers(ccd, t_coeffs, i_coeffs, s_coeffs, p_coeffs, fact=10, readnoise=3.63, verbose=False):
    """ Uses estimated centers from fibers flats as starting point, then
        fits from there to find traces based on science ccd frame.
        INPUTS:
            ccd - image on which to fit traces (rows = first axis, columns = second)
            t/i/s/p_coeffs - modified gaussian coefficients from fiberflat,
                             indexed [k, fiber] with k=0,1,2 the quadratic
                             coefficients in yj = (col - hpix/2)/hpix
                             (i_coeffs is currently unused)
            fact - do 1/fact of the available points (every fact-th column)
            readnoise - added in quadrature to the pixel variance
            verbose - print progress
        OUTPUTS:
            trace_coeffs_ccd - (11, num_fibers) polynomial coefficients in
                               increasing order of yj; all nan for a fiber with
                               too few usable columns
    """
    num_fibers = np.shape(t_coeffs)[1]  # t_coeffs is indexed [coeff, fiber]
    hpix = np.shape(ccd)[1]
    vpix = np.shape(ccd)[0]
    rough_pts = int(np.ceil(hpix/float(fact)))
    vc_ccd = np.full((num_fibers,rough_pts), np.nan)  # nan = no usable fit
    hc_ccd = np.zeros((num_fibers,rough_pts))
    inv_chi = np.zeros((num_fibers,rough_pts))
    xpad = 7
    if verbose:
        print("Refining trace centers")
    for i in range(num_fibers):
        if verbose:
            print("Running on index {}".format(i))
        for j in range(0,hpix,fact):
            jadj = j//fact
            yj = (j-hpix/2.0)/hpix
            hc_ccd[i,jadj] = j
            vc = t_coeffs[2,i]*yj**2+t_coeffs[1,i]*yj+t_coeffs[0,i]
            if np.isnan(vc):
                continue
            sigj = s_coeffs[2,i]*yj**2+s_coeffs[1,i]*yj+s_coeffs[0,i]
            powj = p_coeffs[2,i]*yj**2+p_coeffs[1,i]*yj+p_coeffs[0,i]
            xj = int(vc)
            xvals = np.arange(-xpad,xpad+1)
            xwindow = xj+xvals
            xvals = xvals[(xwindow>=0)&(xwindow<vpix)]
            zorig = ccd[xj+xvals,j]
            ### Skip empty windows and very low signal columns
            if len(zorig)<1 or np.max(zorig)<20:
                continue
            invorig = 1/(abs(zorig)+readnoise**2)
            mn_new, hght, bg = fit_mn_hght_bg(xvals, zorig, invorig, sigj, vc-xj-1, sigj, powj=powj)
            fitorig = sf.gaussian(xvals,sigj,mn_new,hght,power=powj)
            inv_chi[i,jadj] = 1/sum((zorig-fitorig)**2*invorig)
            # The guess above subtracts 1, so add it back for the absolute center.
            vc_ccd[i,jadj] = mn_new+xj+1

    tmp_poly_ord = 10
    trace_coeffs_ccd = np.zeros((tmp_poly_ord+1,num_fibers))
    for i in range(num_fibers):
        mask = ~np.isnan(vc_ccd[i,:])
        ### Need more points than coefficients for a determined fit
        if np.count_nonzero(mask) <= tmp_poly_ord+1:
            trace_coeffs_ccd[:,i] = np.nan
            continue
        yrel = (hc_ccd[i,mask]-hpix/2.0)/hpix
        profile = yrel[:,None]**np.arange(tmp_poly_ord+1)  # polynomial design matrix
        noise = np.diag(inv_chi[i,mask])
        tmp_coeffs, junk = sf.chi_fit(vc_ccd[i,mask],profile,noise)
        trace_coeffs_ccd[:,i] = tmp_coeffs
    return trace_coeffs_ccd
def _refine_extraction_traces(ccd, t_coeffs, s_coeffs, p_coeffs, readnoise, gain, fact=20, tmp_poly_ord=6):
    """ Refit trace centers on ccd (starting from t_coeffs) for extract_1D.

        Every fact-th column is sampled; a window of up to 15 rows around the
        predicted center is fit with linear_mn_hght_bg. The centers are then
        fit with a polynomial of order tmp_poly_ord in yj = (col - hpix/2)/hpix,
        weighted by 1/chi^2 of each center fit.
        Returns (tmp_poly_ord+1, num_fibers) coefficients in increasing order;
        a fiber with too few usable samples gets all-nan coefficients.
    """
    hpix = np.shape(ccd)[1]
    vpix = np.shape(ccd)[0]
    num_fibers = np.shape(t_coeffs)[1]
    rough_pts = int(np.ceil(hpix/float(fact)))
    xc_ccd = np.full((num_fibers,rough_pts), np.nan)  # nan = no usable fit
    yc_ccd = np.zeros((num_fibers,rough_pts))
    inv_chi = np.zeros((num_fibers,rough_pts))
    xpad = 7
    for i in range(num_fibers):
        for j in range(0,hpix,fact):
            jadj = j//fact
            yj = (j-hpix/2.0)/hpix
            yc_ccd[i,jadj] = j
            xc = t_coeffs[2,i]*yj**2+t_coeffs[1,i]*yj+t_coeffs[0,i]
            if np.isnan(xc):
                continue  # bad trace section
            sigj = s_coeffs[2,i]*yj**2+s_coeffs[1,i]*yj+s_coeffs[0,i]
            powj = p_coeffs[2,i]*yj**2+p_coeffs[1,i]*yj+p_coeffs[0,i]
            xj = int(xc)
            xvals = np.arange(-xpad,xpad+1)
            xwindow = xj+xvals
            xvals = xvals[(xwindow>=0)&(xwindow<vpix)]
            zorig = gain*ccd[xj+xvals,j]
            ### Skip empty windows and very low SNR peaks
            if len(zorig)<1 or np.max(zorig)<20:
                continue
            invorig = 1/(abs(zorig)+readnoise**2) ### inverse variance
            mn_new, hght, bg = linear_mn_hght_bg(xvals,zorig,invorig,sigj,xc-xj-1,power=powj)
            fitorig = sf.gaussian(xvals,sigj,mn_new,hght,power=powj)
            inv_chi[i,jadj] = 1/sum((zorig-fitorig)**2*invorig)
            # NOTE: guess is xc-xj-1 and the result adds 1 back; this one-pixel
            # convention matches how t_coeffs are consumed - kept as-is.
            xc_ccd[i,jadj] = mn_new+xj+1
    t_coeffs_ccd = np.zeros((tmp_poly_ord+1,num_fibers))
    for i in range(num_fibers):
        mask = ~np.isnan(xc_ccd[i,:])
        if np.count_nonzero(mask) <= tmp_poly_ord+1:
            ### if not enough points to fit, call entire trace bad
            t_coeffs_ccd[:,i] = np.nan
            continue
        yrel = (yc_ccd[i,mask]-hpix/2.0)/hpix
        profile = yrel[:,None]**np.arange(tmp_poly_ord+1)
        noise = np.diag(inv_chi[i,mask])
        tmp_coeffs, junk = sf.chi_fit(xc_ccd[i,mask],profile,noise)
        t_coeffs_ccd[:,i] = tmp_coeffs
    return t_coeffs_ccd


def _unit_area_profile(xvals, sigj, mn_new, hght, powj):
    """ Evaluate the fitted profile on xvals, normalized to unit area.

        The area is integrated numerically on 100 points spanning
        xvals[0]..xvals[-1]. Returns zeros when that area is 0 (failed fit or
        a single-pixel window) so callers never divide by zero. xvals must be
        non-empty.
    """
    fitorig = sf.gaussian(xvals,sigj,mn_new,hght,power=powj)
    xprecise = np.linspace(xvals[0],xvals[-1],100)
    fitprecise = sf.gaussian(xprecise,sigj,mn_new,hght,power=powj)
    ftmp = sum(fitprecise)*np.mean(np.ediff1d(xprecise))
    if ftmp==0:
        return np.zeros(len(xvals))
    return fitorig/ftmp


def _fit_and_extract_slice(xvals, zorig, invorig, sigj, guess, powj, readnoise):
    """ Fit one cross-dispersion slice and compute its optimally weighted flux.

        Returns (fstd, fitnorm, invorig_model, bg, chi2red) where fitnorm is
        the unit-area profile, invorig_model = 1/(readnoise**2 + |fstd*fitnorm|),
        bg is the background term returned by fit_mn_hght_bg and chi2red is
        chi^2/(npix-3) (nan when npix <= 3).
    """
    mn_new, hght, bg = fit_mn_hght_bg(xvals,zorig,invorig,sigj,guess,sigj/8,powj=powj)
    fitnorm = _unit_area_profile(xvals, sigj, mn_new, hght, powj)
    fstd = sum(fitnorm*zorig*invorig)/sum(fitnorm**2*invorig)
    invorig = 1/(readnoise**2 + abs(fstd*fitnorm))
    dof = len(zorig)-3
    if dof > 0:
        chi2red = np.sum((fstd*fitnorm+bg-zorig)**2*invorig)/dof
    else:
        chi2red = np.nan
    return fstd, fitnorm, invorig, bg, chi2red


def extract_1D(ccd, t_coeffs, i_coeffs=None, s_coeffs=None, p_coeffs=None, readnoise=1, gain=1, return_model=False, verbose=False):
    """ Function to extract using optimal extraction method.
        Trace centers are first refit on ccd (t_coeffs come from a fiber
        flat), then every column of every fiber is fit and optimally
        extracted with iterative cosmic ray rejection.
        INPUTS:
        ccd - ccd image to extract (rows = first axis, columns = second)
        t_coeffs - estimate of trace coefficients (from 'find_t_coeffs'),
                   indexed [k, fiber] with k=0,1,2 quadratic coefficients in
                   yj = (col - hpix/2)/hpix
        i_coeffs - intensity coefficients (currently unused)
        s_coeffs, p_coeffs - sigma and power coefficients, same layout as
                   t_coeffs; required despite the None defaults
        readnoise, gain - of the ccd; ccd values are multiplied by gain
        return_model - set True to return model of image based on extraction
        verbose - print progress and the average reduced chi^2
        OUTPUTS:
        spec - extracted spectrum (n x hpix) where n is number of traces
        spec_invar - inverse variance at each point in extracted spectrum
        spec_mask - False for invalid/suspect points in spectrum
        image_model - only if return_model = True.
        RAISES:
        ValueError if s_coeffs or p_coeffs is missing
    """
    if s_coeffs is None or p_coeffs is None:
        raise ValueError("extract_1D requires s_coeffs and p_coeffs")
    hpix = np.shape(ccd)[1]
    vpix = np.shape(ccd)[0]
    num_fibers = np.shape(t_coeffs)[1]

    ta = time.time()  ### Start time of trace refinement
    if verbose:
        print("Refining trace centers")
    t_coeffs_ccd = _refine_extraction_traces(ccd, t_coeffs, s_coeffs, p_coeffs, readnoise, gain)
    tb = time.time() ### Start time of extraction/end of trace refinement
    if verbose:
        print("Trace refinement time = {}s".format(tb-ta))

    spec = np.zeros((num_fibers,hpix))
    spec_invar = np.zeros((num_fibers,hpix))
    spec_mask = np.ones((num_fibers,hpix),dtype=bool)
    chi2red_array = np.full((num_fibers,hpix), np.nan)  # nan = not fit
    if return_model:
        image_model = np.zeros((np.shape(ccd))) ### Used for evaluation
    xpad = 5  ### can't be too big or traces start to overlap
    for i in range(num_fibers):
        if verbose:
            print("extracting trace {}".format(i+1))
        trace = np.poly1d(t_coeffs_ccd[::-1,i])
        for j in range(hpix):
            yj = (j-hpix/2.0)/hpix
            xc = trace(yj)
            ### If trace center is undefined mask the point
            if np.isnan(xc):
                spec_mask[i,j] = False
                continue
            sigj = s_coeffs[2,i]*yj**2+s_coeffs[1,i]*yj+s_coeffs[0,i]
            powj = p_coeffs[2,i]*yj**2+p_coeffs[1,i]*yj+p_coeffs[0,i]
            xj = int(xc)
            xvals = np.arange(-xpad,xpad+1)
            xwindow = xj+xvals
            xvals = xvals[(xwindow>=0)&(xwindow<vpix)]
            zorig = gain*ccd[xj+xvals,j]
            if len(zorig)<1:
                spec_mask[i,j] = False
                continue
            ### don't try to extract for very low signal
            if np.max(zorig)<20:
                continue
            invorig = 1/(abs(zorig)+readnoise**2)
            guess = xc-xj-1
            fstd, fitnorm, invorig, bg, chi2red = _fit_and_extract_slice(
                xvals, zorig, invorig, sigj, guess, powj, readnoise)
            ### Cosmic ray rejection: refit after each round that rejects pixels
            loop_count = 0
            while True:
                # NOTE(review): verbose=True is hard-coded as in the original; consider passing verbose.
                pixel_reject = cosmic_ray_reject(zorig,fstd,fitnorm,invorig,S=bg,threshhold=0.3*np.mean(zorig),verbose=True)
                any_rejected = np.min(pixel_reject)==0
                if any_rejected:
                    keep = pixel_reject==1
                    if not np.any(keep):
                        ### every pixel flagged: nothing left to refit
                        spec_mask[i,j] = False
                        break
                    zorig = zorig[keep]
                    invorig = invorig[keep]
                    xvals = xvals[keep]
                    fstd, fitnorm, invorig, bg, chi2red = _fit_and_extract_slice(
                        xvals, zorig, invorig, sigj, guess, powj, readnoise)
                ### after more than 4 rejection rounds, mask the extracted flux
                if loop_count>3:
                    spec_mask[i,j] = False
                    break
                loop_count+=1
                if not any_rejected:
                    break
            spec[i,j] = fstd
            spec_invar[i,j] = sum(fitnorm**2*invorig)
            chi2red_array[i,j] = chi2red
            if np.isnan(fstd):
                ### Failed extraction: zero it and mask
                spec[i,j] = 0
                spec_invar[i,j] = 0
                spec_mask[i,j] = False
            elif return_model:
                image_model[xj+xvals,j] += (fstd*fitnorm+bg)/gain
    if verbose:
        fitted = ~np.isnan(chi2red_array)
        avg_chi2 = np.mean(chi2red_array[fitted]) if np.any(fitted) else float('nan')
        print("Average reduced chi^2 = {}".format(avg_chi2))
    if return_model:
        return spec, spec_invar, spec_mask, image_model
    else:
        return spec, spec_invar, spec_mask