def spline_coeff_eval(raw_img,hcenters,hc_ref,vcenters,vc_ref,invar,r_breakpoints,spline_poly_coeffs,s_scale,sigmas,powers,cpad=5,bp_space=2,full_image=True,view_plot=False,sp_coeffs=None,ecc_pa_coeffs=None):
    """ Another highly specialized function. Evaluates coeffs found in
        sf.interpolate_coeffs and builds the matching model image with
        spline.spline_2D_radial.  Optionally displays the fit compared
        to the data.

        Parameters
        ----------
        raw_img : 2D array
            Image data.  Used whole when full_image is True, otherwise a
            (2*cpad+1) square cutout around each center is taken.
        hcenters, vcenters : sequences
            Horizontal / vertical centers.  len(vcenters) sets the loop
            length; hcenters[k] is also the point at which the
            polynomial coefficients are evaluated.
        hc_ref, vc_ref : scalars
            Subtracted from the centers to form the 'hc' and 'vc'
            parameters handed to the spline model.
        invar : 2D array
            Indexed exactly like raw_img; sqrt(invar) weights the
            residuals in the diagnostic plot.
        r_breakpoints :
            Passed straight to spline.spline_2D_radial.
        spline_poly_coeffs : 2D array
            Column l (spline_poly_coeffs[:,l]) is evaluated at
            hcenters[k] to give spline coefficient l.
        s_scale : sequence
            s_scale[k] is passed as `sscale` to the spline model.
        sigmas, powers : sequences
            Only used when full_image is False, as inputs to
            sf.fit_mn_hght_bg.
        cpad : int
            Half-width of the cutout when full_image is False.
        bp_space, sp_coeffs :
            Not used by this function.
        full_image : bool
            Choose between the whole image and a per-center cutout.
        view_plot : bool
            If True, print a reduced chi^2 and show diagnostic plots.
        ecc_pa_coeffs : sequence of two coefficient sets or None
            If given, evaluated at hcenters[k] to set 'q' and 'PA';
            otherwise both are 1.

        Returns
        -------
        spline_fit for the FIRST center only (see note at the return
        statement), or None when vcenters is empty.
    """
    voff = 1
    for k in range(len(vcenters)):
        if full_image:
            hmean = hcenters[k]
            vmean = vcenters[k]
            small_img = raw_img
            small_inv = invar
        else:
            harr = np.arange(-cpad,cpad+1)+hcenters[k]
            varr = np.arange(-cpad,cpad+1)+vcenters[k]
            # Slice bounds must be integers: with float centers harr/varr
            # are float arrays and numpy rejects them as slice indices.
            v_lo, v_hi = int(varr[0]), int(varr[-1])+1
            h_lo, h_hi = int(harr[0]), int(harr[-1])+1
            small_img = raw_img[v_lo:v_hi,h_lo:h_hi]
            small_inv = invar[v_lo:v_hi,h_lo:h_hi]
            # Only the fitted means are needed; height and background
            # from the fits are discarded.
            hmean, _hheight, _hbg = sf.fit_mn_hght_bg(harr,small_img[cpad,:],small_inv[cpad,:],sigmas[k],hcenters[k],sigmas[k],powj=powers[k])
            vmean, _vheight, _vbg = sf.fit_mn_hght_bg(varr+voff,small_img[:,cpad],small_inv[:,cpad],sigmas[k],vcenters[k],sigmas[k],powj=powers[k])
        theta_orders = [0]
        # Evaluate each spline coefficient's polynomial at this center.
        n_coeffs = len(spline_poly_coeffs[0])
        spline_coeffs = np.zeros((n_coeffs))
        for l in range(n_coeffs):
            spline_coeffs[l] = sf.eval_polynomial_coeffs(hcenters[k],spline_poly_coeffs[:,l])
        # 'q' and 'PA' default to 1 unless polynomial coeffs are supplied.
        ecc_pa = np.ones((2))
        if ecc_pa_coeffs is not None:
            for m in range(2):
                ecc_pa[m] = sf.eval_polynomial_coeffs(hcenters[k],ecc_pa_coeffs[m])
        params = lmfit.Parameters()
        params.add('vc', value = vmean-vc_ref)
        params.add('hc', value = hmean-hc_ref)
        params.add('q',ecc_pa[0])
        params.add('PA',ecc_pa[1])
        spline_fit = spline.spline_2D_radial(small_img,small_inv,r_breakpoints,params,theta_orders=theta_orders,spline_coeffs=spline_coeffs,sscale=s_scale[k])
        if view_plot:
            plt.close()
            res = (small_img-spline_fit)*np.sqrt(small_inv)
            print("Chi^2 reduced = {}".format(np.sum(res**2)/(np.size(res)-2-n_coeffs)))
            # Stack data, model and weighted residuals for one image.
            vis = np.vstack((small_img,spline_fit,res))
            plt.imshow(vis,interpolation='none')
            plt.show()
            plt.close()
            # Column 5 is the central column of a default (cpad=5) cutout.
            plt.plot(spline_fit[:,5])
            plt.show()
            plt.close()
        # NOTE(review): this return sits inside the loop, so only k=0 is
        # ever evaluated.  Kept as-is to preserve the existing return
        # contract - confirm whether all centers were meant to be used.
        return spline_fit
def extract_2D(ccd, psf_coeffs, t_coeffs, i_coeffs=None, s_coeffs=None, p_coeffs=None, readnoise=1, gain=1, return_model=False, verbose=False):
    """ Code to perform 2D spectroperfectionism algorithm on MINERVA data.

        For every fiber the CCD is walked in horizontal sections
        (consecutive sections overlap by 2*len_edge columns).  For each
        section a profile cube A is filled with one PSF stamp per
        wavelength point, flattened into a matrix B (plus a constant
        background column) and passed with the data and inverse variance
        to sf.extract_2D_sparse.  After the loops some timing and
        diagnostic plots are produced.

        Parameters
        ----------
        ccd : 2D array
            CCD image, indexed [vertical, horizontal].
        psf_coeffs : 2D array
            Row `fib` holds that fiber's PSF description: entries
            [-7:-1] are reshaped to (2,3) and passed to array_to_params;
            entries [:-7] are reshaped to (n/3, 3) spline coefficients.
        t_coeffs : 2D array
            Trace coefficients; refined with refine_trace_centers and
            then read per fiber as t_coeffs[:,fib].
        i_coeffs : optional
            Only forwarded to refine_trace_centers.
        s_coeffs, p_coeffs : 2D arrays
            Indexed as [:,fib] unconditionally, so despite the None
            defaults they must be supplied.
        readnoise, gain, return_model, verbose :
            Not used in the code visible here.

        NOTE(review): no return statement is visible in this section and
        extracted_counts is never filled.  Several names used below
        (wl_pad, d0, d1, tstart, ccd_small, ccd_small_invar) are not
        assigned anywhere in the visible function - they must come from
        module scope or the function is unfinished.  Confirm before use.
    """
    ### Set shape variables based on inputs
    # NOTE(review): fiber count is read from axis 0 here, but fibers are
    # selected along axis 1 (t_coeffs[:,fib]) below - confirm orientation.
    num_fibers = t_coeffs.shape[0]
    hpix = ccd.shape[1]
    hscale = (np.arange(hpix)-hpix/2)/hpix
    extracted_counts = np.zeros((num_fibers,hpix))
    ### Remove CCD diffuse background - cut value matters
    cut = np.median(np.median(ccd[ccd<np.median(ccd)]))
    ccd, bg_err = remove_ccd_background(ccd,cut=cut,plot=True)
    ### Fit input trace coeffs (from fiberflat) to this ccd
    t_coeffs = refine_trace_centers(ccd,t_coeffs,i_coeffs,s_coeffs,p_coeffs)
    ### Parameters for extraction box size - try various values
    ### For meaning, see documentation
    num_sections = 16
    len_section = 143
    fit_pad = 4
    v_pad = 6
    len_edge = fit_pad*2
    ### iterate over all fibers
    for fib in range(num_fibers):
        print("Running 2D Extraction on fiber {}".format(fib))
        ### Trace parameters
        vcents = sf.eval_polynomial_coeffs(hscale,t_coeffs[:,fib])
        sigmas = sf.eval_polynomial_coeffs(hscale,s_coeffs[:,fib])
        powers = sf.eval_polynomial_coeffs(hscale,p_coeffs[:,fib])
        ### PSF parameters
        ellipse = psf_coeffs[fib,-7:-1]
        ellipse = ellipse.reshape((2,3))
        params = array_to_params(ellipse)
        coeff_matrix = psf_coeffs[fib,:-7]
        # Floor division: reshape needs an integer row count (true
        # division would hand it a float).
        coeff_matrix = coeff_matrix.reshape((coeff_matrix.size//3,3))
        for sec in range(num_sections):
            ### Get a small section of ccd to extract
            hsec = np.arange(sec*(len_section-2*len_edge), len_section+sec*(len_section-2*len_edge))
            vcent = np.mean(vcents[hsec])
            # NOTE(review): vcent is a float mean, so these slice bounds
            # are floats, and the slice spans 2*v_pad rows while A below
            # has 2*v_pad+1 rows - confirm intended rounding/extent.
            ccd_sec = ccd[vcent-v_pad:vcent+v_pad,hsec]
            ccd_sec_invar = 1/(ccd_sec + bg_err**2)
            ### set coordinates for opposite corners of box (for profile matrix)
            vtl = vcent-v_pad
            htl = hsec[0]
            vbr = vcent+v_pad
            hbr = hsec[-1]
            ### Optional - test removing background again
            ccd_sec, sec_bg_err = remove_ccd_background(ccd_sec,cut=3*bg_err)
            ### number of wavelength points to extract, default 1/pixel
            wls = len_section
            # NOTE(review): the grid starts at 0 rather than hsec[0], and
            # vcents[jj] below is indexed from 0 for every section -
            # confirm whether section offsets were intended.
            hcents = np.linspace(0,hsec[-1],wls)
            A = np.zeros((wls,2*v_pad+1,len(hsec)))
            for jj in range(wls):
                hcent = hcents[jj]
                vcent = vcents[jj]
                vcent -= 1  ### Something is wrong above - shouldn't need this...
                # Sub-pixel part of the center, and the scaled horizontal
                # position at which the PSF coefficients are evaluated.
                center = [np.mod(hcent,1),np.mod(vcent,1)]
                hpoint = (hcent-hpix/2)/hpix
                ### Now build PSF model around center point
                psf_type = 'bspline'
                if psf_type == 'bspline':
                    ### TODO - revamp this to pull from input
                    r_breakpoints = np.hstack(([0, 1.5, 2.4, 3],np.arange(3.5,8.6,1)))
                    theta_orders = [0]
                    psf_jj = spline.make_spline_model(params, coeff_matrix, center, hpoint, [2*fit_pad+1,2*fit_pad+1], r_breakpoints, theta_orders, fit_bg=False)
                    # Subtract a median "floor" of the faint pixels, then
                    # normalize the stamp to unit sum.
                    bg_lvl = np.median(psf_jj[psf_jj<np.mean(psf_jj)])
                    psf_jj -= bg_lvl
                    psf_jj /= np.sum(psf_jj) # Normalize to 1
                ### indices of the PSF stamp to use (clipped at box edges)
                sp_l = max(0,fit_pad+(htl-int(hcent))) #left edge
                sp_r = min(2*fit_pad+1,fit_pad+(hbr-int(hcent))) #right edge
                sp_t = max(0,fit_pad+(vtl-int(vcent))) #top edge
                sp_b = min(2*fit_pad+1,fit_pad+(vbr-int(vcent))) #bottom edge
                ### indices of A slice to use
                a_l = max(0,int(hcent)-htl-fit_pad) # left edge
                a_r = min(A.shape[2],int(hcent)-htl+fit_pad+1) # right edge
                a_t = max(0,int(vcent)-vtl-fit_pad) # top edge
                a_b = min(A.shape[1],int(vcent)-vtl+fit_pad+1) # bottom edge
                # NOTE(review): wl_pad is not assigned in this function;
                # A has exactly wls planes, so presumably wl_pad == 0.
                A[jj+wl_pad,a_t:a_b,a_l:a_r] = psf_jj[sp_t:sp_b,sp_l:sp_r]
            ##Now using the full available data
            # NOTE(review): d0, d1 are not assigned in this function;
            # presumably the (rows, cols) shape of the section - confirm.
            B = np.matrix(np.resize(A.T,(d0*d1,wls)))
            B = np.hstack((B,np.ones((d0*d1,1)))) ### add background term
            p = np.matrix(np.resize(ccd_sec.T,(d0*d1,1)))
            n = np.diag(np.resize(ccd_sec_invar.T,(d0*d1,)))
            text_sp_st = time.time()
            fluxtilde2 = sf.extract_2D_sparse(p,B,n)
            t_betw_ext = time.time()
            tfinish = time.time()
    # Everything below runs once, after all fibers/sections, and therefore
    # only sees p, B, n and the timers from the last section processed.
    # NOTE(review): tstart is not assigned in this function.
    print("Total Time = {}".format(tfinish-tstart))
    print("PSF modeling took {}s".format(text_sp_st-tstart))
    print("Sparse extraction took {}s".format(t_betw_ext-text_sp_st))
    flux2 = sf.extract_2D_sparse(p,B,n,return_no_conv=True)
    # (A dense alternative - SVD pseudo-inverse of B^T N^-1 B followed by
    # re-convolution with its row-normalized matrix square root - used to
    # be sketched here in commented-out form.)

    img_est = np.dot(B,flux2)
    img_estrc = np.dot(B,fluxtilde2)
    img_recon = np.real(np.resize(img_estrc,(d1,d0)).T)
    # NOTE(review): ccd_small / ccd_small_invar are not assigned in this
    # function; presumably ccd_sec / ccd_sec_invar - confirm.
    plt.figure("Residuals of 2D fit")
    plt.imshow(np.vstack((ccd_small,img_recon,ccd_small-img_recon)),interpolation='none')
    chi_red = np.sum((ccd_small-img_recon)[:,fit_pad:-fit_pad]**2*ccd_small_invar[:,fit_pad:-fit_pad])/(np.size(ccd_small[:,fit_pad:-fit_pad])-jj+1)
    print("Reduced chi2 = {}".format(chi_red))
    plt.show()
    plt.figure("Cross section of fit, residuals")
    for i in range(20,26):
        # Model column and its residual against the data, overplotted.
        plt.plot(img_recon[:,i])
        plt.plot((ccd_small-img_recon)[:,i])
    plt.show()
    plt.close()