コード例 #1
0
def fit_wave_soln(fnlist, doprint=False):
    """Fits a wavelength solution to rings in a set of images. Appends this
    wavelength solution to the image headers as the keywords:
    'fpwave0' and 'fpcalf'

    Each object (non-ARC) image in fnlist must have a corresponding
    "median.fits" in its image directory, or this routine exits with an
    error. That median image, and then the median of the pixels whose badp
    is not 1, are subtracted from the object image before marking rings.

    Ring radii are marked interactively, first in the object (night sky)
    images and then in any ARC images. Each object image's ring profile is
    built about that image's own ring center; every ARC image uses the ring
    center of the first object image. Rings from both sets are combined to
    determine a wavelength solution for the whole set of images.

    If the ARC rings disagree substantially with the night sky rings, it is
    recommended that users delete the ARC rings from the fit and use only the
    night sky rings.

    It is also known that the wavelength solution can sometimes be piecewise in
    time when a large jump in 'z' happens between two images; i.e. different
    wavelength solutions exist before and after the jump. The routine allows
    the user to make a piecewise solution for this reason, but this ability
    should be used sparingly.

    This routine contains one of the few hard-coded numbers in the pipeline,
    Fguess=5600, used when the first object image carries no F value. It is
    a reasonable guess.

    Besides wave0 and F, the rms of the accepted fit and the numbers of
    rings, independent fits, and parameters per fit are stored on every
    image (calrms, calnring, calnfits, calnpars).

    Inputs:
    fnlist -> List of strings, each the path to a fits image. These images
    should all have been taken with the same order filter. If not, the routine
    will crash.
    doprint -> If True, print the z, JD, and radius of every marked ring.

    """

    # This bit takes care of the 's' to save shortcut in matplotlib.
    oldsavekey = plt.rcParams["keymap.save"]
    plt.rcParams["keymap.save"] = ""
    try:
        images, arclist, objlist, arclibs, nightlibs = (
            _load_wave_soln_images(fnlist))
        if not objlist:
            # The ARC ring center and the B and F guesses all come from the
            # first object image, so at least one is required.
            exit("Error! At least one non-ARC image is required.")

        # Mark rings in the object images, each about its own ring center
        radlists = _mark_rings(objlist,
                               [(image.xcen, image.ycen) for image in objlist],
                               "Finished marking sky rings? (y/n) ")
        zo, to, ro, lib_o = _collect_rings(objlist, radlists, nightlibs,
                                           doprint)

        # Mark rings in the ARC images (if any), all about the first object
        # image's ring center
        center = (objlist[0].xcen, objlist[0].ycen)
        radlists = _mark_rings(arclist, [center]*len(arclist),
                               "Finished marking ARC rings? (y/n) ")
        za, ta, ra, lib_a = _collect_rings(arclist, radlists, arclibs,
                                           doprint)

        # Initial guesses for B and F; F falls back to the hard-coded value
        Bguess = objlist[0].b
        Fguess = objlist[0].f
        if Fguess is None:
            Fguess = 5600

        # Figure out A by matching rings to the wavelength libraries
        master_r = np.array(ro+ra)
        master_z = np.array(zo+za)
        if len(master_r) == 0:
            exit("Error! No rings were marked.")
        bestA, master_wave = _match_ring_wavelengths(master_r, master_z,
                                                     lib_o+lib_a,
                                                     Bguess, Fguess)

        # Ring times in minutes since the earliest ring's exposure
        master_t = np.array(to+ta)
        t0 = np.min(master_t)
        master_t = (master_t-t0)*24*60
        master_color = np.array(len(ro)*["blue"]+len(ra)*["red"])
        toggle = np.ones(len(master_r), dtype="bool")
        dotime = False
        time_dividers = []

        # Do the interactive plotting
        while True:
            rplot = master_r[toggle]
            zplot = master_z[toggle]
            tplot = master_t[toggle]
            colorplot = master_color[toggle]
            wplot = master_wave[toggle]
            xs = np.array([rplot, zplot, tplot], dtype=float)
            # Kept sorted so the header-writing step can look up divisions
            time_dividers = sorted(time_dividers)
            if len(time_dividers) > 1:
                print("Warning: Too many time divisions is likely unphysical. "
                      "Be careful!")
            if dotime:
                p0 = (bestA, Bguess, 0, Fguess)
            else:
                p0 = (bestA, Bguess, Fguess)
            fit, fitplot = _fit_wave_pieces(xs, wplot, tplot, time_dividers,
                                            dotime, p0)
            resid = wplot-fitplot

            # Interactively plot the residuals
            solnplot = WaveSolnPlot(rplot, zplot, tplot, wplot,
                                    resid, colorplot, time_dividers)
            key = solnplot.key

            if key == "a":
                # Report the solution(s) and ask whether to accept
                names = ("A", "B", "E", "F") if dotime else ("A", "B", "F")
                for k, params in enumerate(fit):
                    print("Solution "+repr(k+1)+": " +
                          ", ".join(name+" = "+str(value)
                                    for name, value in zip(names, params)))
                rms = np.sqrt(np.average(resid**2))
                print("Residual rms="+str(rms) +
                      " for "+repr(len(time_dividers)+1) +
                      " independent "+repr(3+dotime) +
                      "-parameter fits to "+repr(len(rplot))+" rings.")
                if _ask_yes_no("Accept wavelength solution? (y/n) "):
                    break
            elif key == "r":
                # Restore all points
                toggle = np.ones(len(master_r), dtype="bool")
            elif key == "d" and solnplot.axis in (1, 2, 3, 4):
                # Delete the point nearest the click. Panels plot the
                # residual against z (1), R (2), t (3), or wavelength (4).
                xdata = {1: zplot, 2: rplot, 3: tplot, 4: wplot}[solnplot.axis]
                dist2 = (((xdata-solnplot.xcoo)/_span(xdata))**2 +
                         ((resid-solnplot.ycoo)/_span(resid))**2)
                nearest = np.argmin(dist2)
                # A ring is identified by its radius and time
                toggle[np.logical_and(master_r == rplot[nearest],
                                      master_t == tplot[nearest])] = False
            elif key == "t":
                # Toggle the time-dependent fit
                dotime = not dotime
            elif key == "w":
                # Add a time break
                timeplot = TimePlot(tplot, resid, colorplot, time_dividers)
                if timeplot.xcoo is not None:
                    time_dividers.append(timeplot.xcoo)
            elif key == "q":
                # Remove all time breaks
                time_dividers = []

        # Close all images
        for image in images:
            image.close()

        # For each image, write the central wavelength and F to the header
        for fn in fnlist:
            image = FPImage(fn, update=True)
            image_t = (image.jd-t0)*24*60
            # Piece k covers times between dividers k-1 and k, so the piece
            # index is the number of dividers at or before image_t.
            image_fit = fit[np.searchsorted(time_dividers, image_t,
                                            side="right")]
            if dotime:
                image.wave0 = (image_fit[0] +
                               image_fit[1]*image.z +
                               image_fit[2]*image_t)
                image.calf = image_fit[3]
            else:
                image.wave0 = image_fit[0]+image_fit[1]*image.z
                image.calf = image_fit[2]
            image.calrms = rms
            image.calnring = repr(len(rplot))
            image.calnfits = repr(len(time_dividers)+1)
            image.calnpars = repr(3+dotime)
            image.close()
    finally:
        # Restore the old keyword shortcut, even if the routine exits early
        plt.rcParams["keymap.save"] = oldsavekey

    return


def _load_wave_soln_images(fnlist):
    """Open fnlist, split ARC/object images, and median-subtract the objects.

    Returns (images, arclist, objlist, arclibs, nightlibs), where arclibs and
    nightlibs hold the ARC / night sky line library of each entry of arclist
    and objlist respectively. Exits if a median image or filter library is
    missing.
    """
    images = [FPImage(fn) for fn in fnlist]
    arclist, objlist, arclibs, nightlibs = [], [], [], []
    for fn, image in zip(fnlist, images):
        libs = get_libraries(image.filter)
        if libs[0] is None:
            exit("Error! Filter "+image.filter+" not in the wavelength library!")
        if image.object == "ARC":
            arclist.append(image)
            arclibs.append(libs[0])
            continue
        medfile = join(split(fn)[0], "median.fits")
        if not isfile(medfile):
            exit("Error! No 'median.fits' file found.")
        medimage = fits.open(medfile)
        image.inty -= medimage[0].data
        image.inty -= np.median(image.inty[image.badp != 1])
        medimage.close()
        objlist.append(image)
        nightlibs.append(libs[1])
    return images, arclist, objlist, arclibs, nightlibs


def _ring_profile(image, xcen, ycen):
    """Return (rbins, intbins), the radial intensity profile of an image.

    rbins are the outer edges of 1-pixel radius bins, out to the aperture
    radius minus the largest offset between the ring center (xcen, ycen) and
    the aperture center. intbins holds the median intensity of the badp == 0
    pixels in each bin, or 0 for a bin with no such pixels.
    """
    axcen, aycen, arad = image.axcen, image.aycen, image.arad
    # NOTE(review): radii are measured from the aperture center, while the
    # bin range and the plotting assume the ring center (xcen, ycen);
    # confirm this is intended.
    rgrid = image.rarray(axcen, aycen)
    goodpix = image.badp == 0
    rbins = np.arange(arad-int(max(abs(axcen-xcen), abs(aycen-ycen))))+1
    # Always a float array: np.empty_like(rbins) would inherit an integer
    # dtype from rbins (when arad is an int) and truncate the medians.
    intbins = np.zeros(len(rbins))
    for j, rbin in enumerate(rbins):
        goodbinmask = (rgrid > rbin-1) & (rgrid < rbin) & goodpix
        if np.any(goodbinmask):
            intbins[j] = np.median(image.inty[goodbinmask])
    return rbins, intbins


def _mark_rings(images, centers, prompt):
    """Let the user interactively mark ring radii in each image.

    images -> List of FPImage objects to step through ('a'/'d' keys).
    centers -> List of (xcen, ycen) ring centers, one per image.
    prompt -> Question asked when the user steps past either end.

    Keys: 'e' force-marks a ring at the click, 's' deletes the nearest
    marked ring, 'w' fits a Gaussian in r**2 around the click and keeps it
    if the user presses 'w' again.

    Returns a list (one per image) of lists of marked ring radii.
    """
    radlists = [[] for _ in images]
    if not images:
        return radlists
    i = 0
    while True:
        image = images[i]
        xcen, ycen = centers[i]
        axcen, aycen, arad = image.axcen, image.aycen, image.arad
        rbins, intbins = _ring_profile(image, xcen, ycen)

        # Shift/scale the radius and intensity for the purposes of plotting
        plotxcen = xcen-axcen+arad
        plotycen = ycen-aycen+arad
        plotrbins = rbins+plotxcen
        plotintbins = intbins*arad/np.percentile(np.abs(intbins), 98)+plotycen

        # Plot the data interactively
        ringplot = PlotRingProfile(image.inty[aycen-arad:aycen+arad,
                                              axcen-arad:axcen+arad],
                                   plotrbins, plotintbins,
                                   plotxcen, plotycen,
                                   radlists[i],
                                   repr(i+1)+"/"+repr(len(images)))
        key = ringplot.key

        # Changing images and loop breakout conditions
        if key == "d":
            i += 1
        elif key == "a":
            i -= 1
        if i == -1 or i == len(images):
            if _ask_yes_no(prompt):
                break
            i = min(max(i, 0), len(images)-1)
            continue

        if key == "e" and ringplot.xcoo is not None:
            # Force-mark a ring at the clicked x offset from the ring center
            radlists[i].append(ringplot.xcoo-plotxcen)
        elif (key == "s" and ringplot.xcoo is not None and
                len(radlists[i]) > 0):
            # Delete the marked ring nearest the clicked radius
            clicked_r = np.hypot(ringplot.xcoo-plotxcen,
                                 ringplot.ycoo-plotycen)
            distarray = np.abs(np.array(radlists[i])-clicked_r)
            radlists[i].pop(np.argmin(distarray))
        elif key == "w" and ringplot.xcoo is not None:
            # Fit a ring profile within 50 bins of the click. int() because
            # a click coordinate may be a float, which numpy rejects as a
            # slice index.
            lower_index = int(max(ringplot.xcoo-plotxcen-50, 0))
            upper_index = int(min(ringplot.xcoo-plotxcen+50, len(rbins)))
            x = rbins[lower_index:upper_index]**2
            y = intbins[lower_index:upper_index]
            fit = GaussFit(x, y)
            fitplot = PlotRingFit(x, y, fit)
            if fitplot.key == "w":
                radlists[i].append(np.sqrt(fit[2]))
    return radlists


def _collect_rings(imagelist, radlists, libs, doprint):
    """Flatten per-image ring radii into parallel z, jd, radius, library lists."""
    zs, ts, rs, ringlibs = [], [], [], []
    for image, radii, lib in zip(imagelist, radlists, libs):
        for radius in radii:
            zs.append(image.z)
            ts.append(image.jd)
            rs.append(radius)
            ringlibs.append(lib)
            if doprint:
                print("%s %s %s" % (image.z, image.jd, radius))
    return zs, ts, rs, ringlibs


def _match_ring_wavelengths(master_r, master_z, master_lib, Bguess, Fguess):
    """Identify the library line behind each ring.

    Every (ring, library line) pair is tried as an anchor that fixes A. The
    predicted wavelength of every ring then follows from A, B, and F and is
    snapped to the nearest line of that ring's library. The anchor with the
    smallest rms mismatch wins.

    Returns (bestA, master_wave).
    """
    libs = [np.asarray(lib) for lib in master_lib]
    wavematch = np.zeros(len(master_r))  # float, whatever master_r's dtype
    bestrms = np.inf
    bestA = None
    master_wave = None
    for r_i, z_i, lib in zip(master_r, master_z, libs):
        for line in lib:
            # Assume this ring is this line
            Aguess = line*np.sqrt(1+r_i**2/Fguess**2)-Bguess*z_i
            # What are all of the other rings, given this A?
            waveguess = ((Aguess+Bguess*master_z) /
                         np.sqrt(1+master_r**2/Fguess**2))
            for k, ringlib in enumerate(libs):
                wavematch[k] = ringlib[np.argmin(np.abs(ringlib-waveguess[k]))]
            rms = np.sqrt(np.average((waveguess-wavematch)**2))
            if rms < bestrms:
                bestrms = rms
                bestA = Aguess
                master_wave = wavematch.copy()
    return bestA, master_wave


def _fit_wave_pieces(xs, wplot, tplot, time_dividers, dotime, p0):
    """Fit the wavelength relation independently in each time interval.

    time_dividers must be sorted; points exactly on a divider are left out.
    Returns (fit, fitplot): a parameter array per interval and the fitted
    wavelength at every point.
    """
    if dotime:
        func = fpfunc_for_curve_fit_with_t
    else:
        func = fpfunc_for_curve_fit
    edges = [-np.inf]+list(time_dividers)+[np.inf]
    fit = []
    fitplot = np.zeros(len(wplot))
    for lower, upper in zip(edges[:-1], edges[1:]):
        tslice = (tplot > lower) & (tplot < upper)
        params = curve_fit(func, xs[:, tslice], wplot[tslice], p0=p0)[0]
        fitplot[tslice] = func(xs[:, tslice], *params)
        fit.append(params)
    return fit, fitplot


def _ask_yes_no(prompt):
    """Ask prompt until the reply contains 'n'/'N' (False) or 'y'/'Y' (True).

    'n' is checked first, so a reply containing both letters counts as no.
    """
    while True:
        yn = raw_input(prompt)
        if "n" in yn or "N" in yn:
            return False
        if "y" in yn or "Y" in yn:
            return True


def _span(values):
    """Return max(values)-min(values), or 1 when that is zero (safe divisor)."""
    span = np.max(values)-np.min(values)
    if span == 0:
        return 1.0
    return span
コード例 #2
0
def sub_sky_rings(fnlist, medfilelist):
    """Fits for night sky rings in a series of images, then subtracts them off.

    Each image has its median image subtracted and is converted to a
    wavelength map using its 'fpwave0' and 'fpcalf' header values. The night
    sky lines inside the aperture are then fitted, as a sum of Gaussians in
    wavelength, to the image's spectrum (median intensity in bins of about
    0.25 in wavelength).

    The rings are fitted in an interactive way by the user:
    'a'/'d' previous/next image, 'q' finish, 's' delete the nearest line,
    'w' add the nearest line from the extended library (lines within 5 of
    either end of the wavelength range), 'e' force-add a line at the clicked
    wavelength, 'r' save the current fit. Every image's fit must be saved
    before finishing.

    The saved fits are then subtracted from the aperture region of each image
    on disk (pixels that were exactly 0 are reset to 0). The header keyword
    "fpdering" is set to "True" in every image with at least one fitted ring.
    This routine does not modify uncertainty images.

    Inputs:
    fnlist -> List of strings, each the path to a fits image.
    medfilelist -> List of paths to median images for each image to subtract
                   off (probably all the same, but generally could be
                   different)

    """
    nimages = len(fnlist)
    if nimages == 0:
        return

    # This bit takes care of the 's' to save shortcut in matplotlib.
    oldsavekey = plt.rcParams["keymap.save"]
    plt.rcParams["keymap.save"] = ""
    try:
        # Open all of the images, median-subtract them, make wavelength
        # arrays, skywavelibs and spectra for each, and trim the images
        imagelist = []
        wavearraylist = []
        skywavelibs = []
        skywavelibs_extended = []
        spectrum_wave = []
        spectrum_inty = []
        minwavelist = []
        maxwavelist = []
        xcenlist = []
        ycenlist = []
        Flist = []
        wave0list = []
        has_been_saved = np.zeros(nimages, dtype="bool")
        for i in range(nimages):
            image = openfits(fnlist[i])
            imagelist.append(image)
            header = image[0].header
            Flist.append(header["fpcalf"])
            wave0list.append(header["fpwave0"])
            xcen = header["fpxcen"]
            ycen = header["fpycen"]
            arad = header["fparad"]
            axcen = header["fpaxcen"]
            aycen = header["fpaycen"]

            # Median subtract it
            medimage = openfits(medfilelist[i])
            image[0].data += -medimage[0].data
            medimage.close()
            data = image[0].data

            # Make wavelength array about the ring center
            xgrid, ygrid = np.meshgrid(np.arange(data.shape[1]),
                                       np.arange(data.shape[0]))
            r2grid = (xgrid-xcen)**2+(ygrid-ycen)**2
            wavearray = wave0list[i]/np.sqrt(1+r2grid/Flist[i]**2)
            # Wavelength range within arad of the ring center
            apwaves = wavearray[r2grid < arad**2]
            minwave = np.min(apwaves)
            maxwave = np.max(apwaves)
            minwavelist.append(minwave)
            maxwavelist.append(maxwave)

            # Sky lines inside that range, plus an "extended" library of
            # lines within 5 outside of either end
            skywaves = get_libraries(header["FILTER"])[1]
            inside = np.logical_and(skywaves > minwave, skywaves < maxwave)
            below = np.logical_and(skywaves < minwave, skywaves > minwave-5)
            above = np.logical_and(skywaves > maxwave, skywaves < maxwave+5)
            skywavelibs.append(list(skywaves[inside]))
            skywavelibs_extended.append(
                list(skywaves[np.logical_or(below, above)]))

            # Make the spectrum: median of the nonzero pixels in each bin.
            # int() because np.int was removed in NumPy 1.24.
            edges = np.linspace(minwave, maxwave,
                                int((maxwave-minwave)/0.25))
            nonzero = data != 0
            intys = np.zeros_like(edges[1:])
            for j in range(len(edges)-1):
                binmask = (nonzero & (wavearray > edges[j]) &
                           (wavearray < edges[j+1]))
                # An empty bin stays 0: a NaN median would make curve_fit
                # reject the whole spectrum.
                if np.any(binmask):
                    intys[j] = np.median(data[binmask])
            spectrum_wave.append(0.5*(edges[:-1]+edges[1:]))
            spectrum_inty.append(intys)

            # Trim the images and arrays to the aperture size
            image[0].data = data[aycen-arad:aycen+arad, axcen-arad:axcen+arad]
            wavearraylist.append(
                wavearray[aycen-arad:aycen+arad, axcen-arad:axcen+arad])
            xcenlist.append(arad+xcen-axcen)
            ycenlist.append(arad+ycen-aycen)

        def clicked_wavelength(plot, idx):
            """Wavelength at the click in image idx, or None outside the panels."""
            if plot.axis in [1, 3]:
                # Image panel: convert the radius from the ring center
                radius = np.sqrt((plot.xcoo-xcenlist[idx])**2 +
                                 (plot.ycoo-ycenlist[idx])**2)
                return wave0list[idx]/(1+radius**2/Flist[idx]**2)**0.5
            if plot.axis in [2, 4]:
                # Spectrum panel: x already is a wavelength
                return plot.xcoo
            return None

        # Interactive plotting of ring profiles
        addedwaves = [[] for _ in range(nimages)]
        final_fitted_waves = [[] for _ in range(nimages)]
        final_fitted_intys = [[] for _ in range(nimages)]
        final_fitted_sigs = [[] for _ in range(nimages)]
        i = 0
        while True:
            # Determine which wavelengths we want to fit and build a
            # fitting function from them
            waves_to_fit = skywavelibs[i]+addedwaves[i]
            nwaves = len(waves_to_fit)
            func_to_fit = make_sum_gaussians(waves_to_fit)

            # Guesses: zero continuum, the integrated spectrum within 4 of
            # each line, and sigma = 2 for every line
            dwave = spectrum_wave[i][1]-spectrum_wave[i][0]
            intyguesses = []
            for wave in waves_to_fit:
                near = np.logical_and(spectrum_wave[i] < wave+4,
                                      spectrum_wave[i] > wave-4)
                intyguesses.append(np.sum(spectrum_inty[i][near]*dwave))
            guess = [0]+intyguesses+[2]*nwaves

            # Fit the spectrum
            fitsuccess = True
            try:
                fit = curve_fit(func_to_fit, spectrum_wave[i],
                                spectrum_inty[i], p0=guess)[0]
            except RuntimeError:
                print("Warning: Fit did not converge. Maybe remove some erroneous lines from the fit?")
                fit = guess
                fitsuccess = False
            fitcont = fit[0]
            fitintys = np.array(fit[1:nwaves+1])
            fitsigs = np.array(fit[nwaves+1:2*nwaves+1])

            # Subtracted image, ring radii, subtracted spectrum, fit curve
            data = imagelist[i][0].data
            subarray = data.copy()
            sub_spec_inty = spectrum_inty[i]-fitcont
            fit_X = np.linspace(minwavelist[i], maxwavelist[i], 500)
            fit_Y = np.ones_like(fit_X)*fitcont
            radiilist = []
            for wave, inty, sig in zip(waves_to_fit, fitintys, fitsigs):
                subarray += -inty*nGauss(wavearraylist[i]-wave, sig)
                sub_spec_inty += -inty*nGauss(spectrum_wave[i]-wave, sig)
                fit_Y += inty*nGauss(fit_X-wave, sig)
                radiilist.append(Flist[i]*np.sqrt((wave0list[i]/wave)**2-1))

            # Plot the image, spectrum, and fit
            profile_plot = SkyRingPlot(data, subarray,
                                       xcenlist[i], ycenlist[i], radiilist,
                                       spectrum_wave[i], spectrum_inty[i],
                                       sub_spec_inty,
                                       fit_X, fit_Y,
                                       skywavelibs[i], addedwaves[i],
                                       skywavelibs_extended[i],
                                       repr(i+1)+"/"+repr(nimages),
                                       has_been_saved[i], fitsuccess)
            key = profile_plot.key

            # Shifting images and breakout condition
            if key == "a":
                i -= 1
            elif key == "d":
                i += 1
            if i == -1 or i == nimages or key == "q":
                if np.all(has_been_saved):
                    quitloop = False
                    while True:
                        yn = raw_input("Finished fitting ring profiles? (y/n) ")
                        if "n" in yn or "N" in yn:
                            break
                        if "y" in yn or "Y" in yn:
                            quitloop = True
                            break
                    if quitloop:
                        break
                else:
                    print("Error: Ring fits have not yet been saved for the following images:")
                    print((np.arange(nimages)+1)[np.logical_not(has_been_saved)])
            i = min(max(i, 0), nimages-1)

            if key == "s" and waves_to_fit:
                # Delete the line nearest the click
                nearest_wave = None
                if profile_plot.axis in [1, 3]:
                    clicked_radius = np.sqrt(
                        (profile_plot.xcoo-xcenlist[i])**2 +
                        (profile_plot.ycoo-ycenlist[i])**2)
                    nearest_wave = waves_to_fit[np.argmin(
                        np.abs(np.array(radiilist)-clicked_radius))]
                elif profile_plot.axis in [2, 4]:
                    nearest_wave = waves_to_fit[np.argmin(
                        np.abs(np.array(waves_to_fit)-profile_plot.xcoo))]
                if nearest_wave is not None:
                    # Library lines move to the extended list so they can be
                    # re-added with 'w'; force-added lines are dropped.
                    if nearest_wave in skywavelibs[i]:
                        skywavelibs[i].remove(nearest_wave)
                        skywavelibs_extended[i].append(nearest_wave)
                    if nearest_wave in addedwaves[i]:
                        addedwaves[i].remove(nearest_wave)
            elif key == "w":
                # Add the extended-library line nearest the click
                clicked_wave = clicked_wavelength(profile_plot, i)
                extended = skywavelibs_extended[i]
                if clicked_wave is not None and extended:
                    wave_to_add = extended[np.argmin(
                        np.abs(np.array(extended)-clicked_wave))]
                    skywavelibs[i].append(wave_to_add)
                    extended.remove(wave_to_add)
            elif key == "e":
                # Force-add a line at exactly the clicked wavelength
                clicked_wave = clicked_wavelength(profile_plot, i)
                if clicked_wave is not None:
                    addedwaves[i].append(clicked_wave)
            elif key == "r":
                # Save this image's current fit
                final_fitted_waves[i] = list(waves_to_fit)
                final_fitted_intys[i] = fitintys.copy()
                final_fitted_sigs[i] = fitsigs.copy()
                has_been_saved[i] = True

        # Close all of the images (the median images are already closed)
        for image in imagelist:
            image.close()

        # Subtract the ring profiles and update headers
        for i in range(nimages):
            image = openfits(fnlist[i], mode="update")
            arad = image[0].header["fparad"]
            axcen = image[0].header["fpaxcen"]
            aycen = image[0].header["fpaycen"]
            mask = image[0].data == 0  # For re-correcting chip gaps
            for wave, inty, sig in zip(final_fitted_waves[i],
                                       final_fitted_intys[i],
                                       final_fitted_sigs[i]):
                image[0].data[aycen-arad:aycen+arad, axcen-arad:axcen+arad] += \
                    -inty*nGauss(wavearraylist[i]-wave, sig)
            if len(final_fitted_waves[i]) > 0:
                image[0].header["fpdering"] = "True"
            image[0].data[mask] = 0  # For re-correcting chip gaps
            image.close()
    finally:
        # Restore the old keyword shortcut, even after an error
        plt.rcParams["keymap.save"] = oldsavekey

    return
コード例 #3
0
def _sky_click_to_wave(plot, xcen, ycen, wave0, F):
    """Convert the position of a keypress on a ring plot to a wavelength.

    Inputs:
    plot -> A SkyRingPlot or ManualPlot after a keypress; its 'axis', 'xcoo'
            and 'ycoo' attributes are read.
    xcen, ycen -> Ring center in the coordinates of the plotted image.
    wave0, F -> Wavelength solution parameters of the image.

    Returns the wavelength wave0 / sqrt(1 + r**2 / F**2), with r the distance
    from the ring center, for a keypress on an image axis (1 or 3); the x
    coordinate for a keypress on a spectrum axis (2 or 4); None otherwise.
    """
    if plot.axis in [1, 3]:
        radius = np.sqrt((plot.xcoo - xcen)**2 + (plot.ycoo - ycen)**2)
        return wave0 / (1 + radius**2 / F**2)**0.5
    if plot.axis in [2, 4]:
        return plot.xcoo
    return None


def _ring_radius(wave0, F, wave):
    """Return the ring radius at which 'wave' appears, i.e. the inverse of
    wave = wave0 / sqrt(1 + r**2 / F**2)."""
    return F * np.sqrt((wave0 / wave)**2 - 1)


def _manual_ring_fit(inty, wavearray, xcen, ycen, wave0, F,
                     spec_wave, spec_inty, minwave, maxwave,
                     skywavelib, skywavelib_extended, is_saved, numstring,
                     waves, intys, sigs, cont):
    """Let the user hand-tune the ring model of one image (manual mode).

    Starts from the given line wavelengths, intensities, sigmas and continuum
    and redraws a ManualPlot after every keypress:
        z     -> select the next line
        c / v -> halve / double the step size
        w / s -> raise / lower the selected line's intensity
        q / a -> increase / decrease the selected line's sigma
        t / g -> raise / lower the continuum
        e     -> add the unused known sky line nearest the keypress
        d     -> remove the modeled line nearest the keypress
        f     -> force-add a line at the keypress wavelength
        r     -> save and leave manual mode
        x     -> leave manual mode without saving

    Inputs:
    inty, wavearray -> Aperture-trimmed image and its wavelength map.
    xcen, ycen -> Ring center in the trimmed image.
    wave0, F -> Wavelength solution parameters.
    spec_wave, spec_inty -> Binned spectrum of the image.
    minwave, maxwave -> Wavelength range covered by the aperture.
    skywavelib, skywavelib_extended -> Known sky lines in use / available to
                                       add. Neither list is modified.
    is_saved, numstring -> Passed through to the plot for display.
    waves, intys, sigs, cont -> Starting model.

    Returns (waves, intys, sigs) lists if the user saved, otherwise None.
    """
    man_waves = list(waves)
    man_intys = list(intys)
    man_sigs = list(sigs)
    man_cont = cont
    # Work on a copy so that adding/removing known lines here cannot leave
    # the automatic-mode libraries inconsistent (a line missing from both, or
    # present in both) after manual mode is exited.
    extended = list(skywavelib_extended)
    # Lines from a sky library are "real"; force-added lines are not.
    man_real = [wave in skywavelib or wave in extended for wave in man_waves]

    # Initialize refinement level and selected line
    selected_line = 0
    refine = 1
    inty_refine = 0.25*(np.max(spec_inty) - np.min(spec_inty))
    sig_refine = 0.5
    cont_refine = inty_refine

    while True:
        model = list(zip(man_waves, man_intys, man_sigs))

        # Model-subtracted image
        subarray = inty.copy()
        for wave, amp, sig in model:
            subarray -= amp*nGauss(wavearray - wave, sig)

        # Radii of the modeled lines first, then of the unused known lines
        radiilist = ([_ring_radius(wave0, F, wave) for wave in man_waves] +
                     [_ring_radius(wave0, F, wave) for wave in extended])

        # Residual spectrum and model curve
        sub_spec_inty = spec_inty - man_cont
        for wave, amp, sig in model:
            sub_spec_inty -= amp*nGauss(spec_wave - wave, sig)
        fit_X = np.linspace(minwave, maxwave, 500)
        fit_Y = np.ones_like(fit_X)*man_cont
        for wave, amp, sig in model:
            fit_Y += amp*nGauss(fit_X - wave, sig)

        # Selected line purple, real lines blue, force-added lines green,
        # unused known lines red
        colors = []
        for k, real in enumerate(man_real):
            if k == selected_line:
                colors.append('purple')
            elif real:
                colors.append('blue')
            else:
                colors.append('green')
        colors += ['red']*len(extended)

        # Plot the spectrum and fit, get user input
        man_plot = ManualPlot(inty, subarray, xcen, ycen,
                              spec_wave, spec_inty, sub_spec_inty,
                              fit_X, fit_Y, is_saved, numstring,
                              radiilist, man_waves+extended, colors)
        key = man_plot.key

        # Save fit / switch back to automatic
        if key == "r":
            return man_waves, man_intys, man_sigs
        if key == "x":
            return None

        # Cycle selected line
        if key == "z":
            selected_line += 1
            if selected_line >= len(man_waves):
                selected_line = 0
            refine = 1

        # Refinement levels
        if key == "c":
            refine *= 0.5
        if key == "v":
            refine *= 2

        # Adjust intensity / width of the selected line
        if man_waves:
            if key == "w":
                man_intys[selected_line] += refine*inty_refine
            if key == "s":
                man_intys[selected_line] -= refine*inty_refine
            if key == "q":
                man_sigs[selected_line] += refine*sig_refine
            if key == "a":
                man_sigs[selected_line] -= refine*sig_refine

        # Increase/decrease continuum
        if key == "t":
            man_cont += refine*cont_refine
        if key == "g":
            man_cont -= refine*cont_refine

        # Add the unused known line nearest the keypress
        if key == "e":
            clicked_wave = _sky_click_to_wave(man_plot, xcen, ycen, wave0, F)
            if clicked_wave is not None and extended:
                near = np.argmin(np.abs(np.array(extended) - clicked_wave))
                man_waves.append(extended.pop(near))
                man_intys.append(0)
                man_sigs.append(1)
                man_real.append(True)
                selected_line = len(man_waves)-1

        # Remove the modeled line nearest the keypress
        if key == "d" and man_waves:
            near_index = None
            if man_plot.axis in [1, 3]:
                clicked_radius = np.sqrt((man_plot.xcoo - xcen)**2 +
                                         (man_plot.ycoo - ycen)**2)
                # Only the first len(man_waves) radii belong to modeled
                # lines; the remainder are unused known lines.
                line_radii = np.array(radiilist[:len(man_waves)])
                near_index = np.argmin(np.abs(line_radii - clicked_radius))
            elif man_plot.axis in [2, 4]:
                near_index = np.argmin(np.abs(np.array(man_waves) -
                                              man_plot.xcoo))
            if near_index is not None:
                removed_wave = man_waves.pop(near_index)
                man_intys.pop(near_index)
                man_sigs.pop(near_index)
                # Real lines become available to add again
                if man_real.pop(near_index):
                    extended.append(removed_wave)
                # Keep the selection pointing at an existing line
                if near_index == selected_line:
                    selected_line = 0
                elif near_index < selected_line:
                    selected_line -= 1

        # Force-add a line at the keypress wavelength
        if key == "f":
            clicked_wave = _sky_click_to_wave(man_plot, xcen, ycen, wave0, F)
            if clicked_wave is not None:
                man_waves.append(clicked_wave)
                man_intys.append(0)
                man_sigs.append(1)
                man_real.append(False)
                selected_line = len(man_waves)-1


def _write_ring_subtraction(fnlist, wavearraylist, waves, intys, sigs):
    """Subtract the saved ring models from the images on disk.

    Each file in fnlist is reopened with FPImage(..., update=True). For every
    saved line j of image i the header keywords 'subwave<j>', 'subinty<j>'
    and 'subsig<j>' record the line, and intys[i][j]*nGauss(...) evaluated on
    the aperture-sized wavearraylist[i] is subtracted inside the aperture.
    Pixels that were exactly zero beforehand (chip gaps) are reset to zero.
    """
    for i, fn in enumerate(fnlist):
        image = FPImage(fn, update=True)
        arad = image.arad
        axcen = image.axcen
        aycen = image.aycen
        chipgaps = image.inty == 0  # For re-correcting chip gaps
        aperture = (slice(aycen-arad, aycen+arad),
                    slice(axcen-arad, axcen+arad))
        for j, (wave, inty, sig) in enumerate(zip(waves[i], intys[i],
                                                  sigs[i])):
            # NOTE(review): for j >= 10 'subwave<j>'/'subinty<j>' exceed the
            # 8-character FITS keyword limit; confirm how they are stored.
            image.header['subwave'+repr(j)] = wave
            image.header['subinty'+repr(j)] = inty
            image.header['subsig'+repr(j)] = sig
            image.inty[aperture] -= inty*nGauss(wavearraylist[i]-wave, sig)
            image.ringtog = "True"
        image.inty[chipgaps] = 0  # For re-correcting chip gaps
        image.close()


def sub_sky_rings(fnlist, medfilelist):
    """Fits for night sky rings in a series of images, then subtracts them off.

    The rings are fitted in an interactive way by the user. Each image is
    median-subtracted (using the matching entry of medfilelist), binned into
    a quarter-angstrom median spectrum inside the aperture, and fitted with a
    constant plus one Gaussian per sky line. Keys in the automatic view:
        a / d -> previous / next image; stepping past either end, or 'q',
                 finishes once every image has been saved
        s     -> remove the fitted line nearest the keypress
        w     -> add the unused known sky line nearest the keypress
        e     -> force-add a line at the keypress wavelength
        r     -> save the current fit for this image
        x     -> hand-tune this image's fit in manual mode

    After the ring subtraction, each saved line j is recorded in the header
    keywords 'subwave<j>', 'subinty<j>' and 'subsig<j>', and FPImage.ringtog
    is set to "True" for images with at least one subtracted ring
    (presumably stored as the "fpdering" header keyword - confirm in FPImage).

    Note: every path in medfilelist must exist or this routine will not work.

    Inputs:
    fnlist -> List of strings, each the path to a fits image.
    medfilelist -> List of paths to median images for each image to subtract
                   off (probably all the same, but could be different)

    """
    if len(fnlist) == 0:
        # Nothing to fit; the interactive loop needs at least one image.
        return

    # Disable matplotlib's 's' (save) and fullscreen shortcuts so those keys
    # reach the interactive plots; restore them even if a step fails.
    oldsavekey = plt.rcParams["keymap.save"]
    oldfullscreenkey = plt.rcParams['keymap.fullscreen']
    plt.rcParams["keymap.save"] = ""
    plt.rcParams['keymap.fullscreen'] = ""
    try:
        nimages = len(fnlist)
        images = [FPImage(fn) for fn in fnlist]

        # Per-image data: aperture-trimmed wavelength maps, sky-line
        # libraries (lines in range, plus lines that may be added), binned
        # spectra and wavelength-solution parameters.
        wavearraylist = []
        skywavelibs = []
        skywavelibs_extended = []
        spectrum_wave = []
        spectrum_inty = []
        minwavelist = []
        maxwavelist = []
        xcenlist = []
        ycenlist = []
        Flist = []
        wave0list = []
        has_been_saved = np.zeros(nimages, dtype="bool")
        for i, image in enumerate(images):
            Flist.append(image.calf)
            wave0list.append(image.wave0)
            xcen = image.xcen
            ycen = image.ycen
            arad = image.arad
            axcen = image.axcen
            aycen = image.aycen

            # Median subtract it (the with-block closes the median file)
            with fits.open(medfilelist[i]) as medimage:
                image.inty -= medimage[0].data
            image.inty -= np.median(image.inty[image.badp != 1])

            # Wavelength of every pixel and the range inside the aperture
            r2grid = image.rarray(xcen, ycen)**2
            wavearray = wave0list[i]/np.sqrt(1+r2grid/Flist[i]**2)
            inaperture = r2grid < arad**2
            minwav = np.min(wavearray[inaperture])
            maxwav = np.max(wavearray[inaperture])
            minwavelist.append(minwav)
            maxwavelist.append(maxwav)

            # Sky lines inside the range, and those within 5 A outside it
            skywaves = get_libraries(image.filter)[1]
            inrange = np.logical_and(skywaves > minwav, skywaves < maxwav)
            lowermask = np.logical_and(skywaves < minwav,
                                       skywaves > minwav-5)
            uppermask = np.logical_and(skywaves > maxwav,
                                       skywaves < maxwav+5)
            skywavelibs.append(list(skywaves[inrange]))
            skywavelibs_extended.append(
                list(skywaves[np.logical_or(uppermask, lowermask)]))

            # Median spectrum of good pixels in quarter-angstrom bins
            # (builtin int: the np.int alias was removed in NumPy 1.24)
            wavstep = 0.25
            edges = np.linspace(minwav, maxwav,
                                int((maxwav-minwav)/wavstep))
            binned = np.zeros_like(edges[1:])
            for j in range(len(edges)-1):
                inbin = np.logical_and(wavearray < edges[j+1],
                                       wavearray > edges[j])
                goodmask = np.logical_and(inbin, image.badp == 0)
                binned[j] = np.median(image.inty[goodmask])
            centers = 0.5*(edges[:-1]+edges[1:])
            # Empty bins give NaN medians; drop them
            notnan = np.logical_not(np.isnan(binned))
            spectrum_wave.append(centers[notnan])
            spectrum_inty.append(binned[notnan])

            # Trim the image and wavelength map to the aperture
            image.inty = image.inty[aycen-arad:aycen+arad,
                                    axcen-arad:axcen+arad]
            wavearraylist.append(wavearray[aycen-arad:aycen+arad,
                                           axcen-arad:axcen+arad])
            xcenlist.append(arad+xcen-axcen)
            ycenlist.append(arad+ycen-aycen)

        # Interactive fitting of ring profiles
        addedwaves = [[] for _ in range(nimages)]
        final_fitted_waves = [[] for _ in range(nimages)]
        final_fitted_intys = [[] for _ in range(nimages)]
        final_fitted_sigs = [[] for _ in range(nimages)]
        i = 0
        while True:
            # Determine which wavelengths we want to fit
            waves_to_fit = skywavelibs[i]+addedwaves[i]
            nlines = len(waves_to_fit)
            func_to_fit = make_sum_gaussians(waves_to_fit)

            # Initial guesses: zero continuum, each line's intensity from
            # the integrated spectrum within 4 A, and 2 A sigmas
            intyguesses = []
            for wave in waves_to_fit:
                near = np.logical_and(spectrum_wave[i] < wave+4,
                                      spectrum_wave[i] > wave-4)
                intyguesses.append(np.sum(spectrum_inty[i][near] *
                                          (spectrum_wave[i][1] -
                                           spectrum_wave[i][0])))
            guess = [0]+intyguesses+[2]*nlines

            # Fit the spectrum
            fitsuccess = True
            try:
                fit = curve_fit(func_to_fit, spectrum_wave[i],
                                spectrum_inty[i], p0=guess)[0]
            except RuntimeError:
                print("Warning: Fit did not converge. " +
                      "Maybe remove some erroneous lines from the fit?")
                fit = guess
                fitsuccess = False
            fitcont = fit[0]
            fitintys = np.array(fit[1:nlines+1])
            fitsigs = np.array(fit[nlines+1:2*nlines+1])
            # Flip signs of negative-sigma fits (sign degeneracy)
            negative = fitsigs < 0
            fitintys[negative] = -1*fitintys[negative]
            fitsigs[negative] = -1*fitsigs[negative]
            model = list(zip(waves_to_fit, fitintys, fitsigs))

            # Model-subtracted image, ring radii, residual spectrum, fit curve
            subarray = images[i].inty.copy()
            for wave, amp, sig in model:
                subarray -= amp*nGauss(wavearraylist[i]-wave, sig)
            radiilist = [_ring_radius(wave0list[i], Flist[i], wave)
                         for wave in waves_to_fit]
            sub_spec_inty = spectrum_inty[i] - fitcont
            for wave, amp, sig in model:
                sub_spec_inty -= amp*nGauss(spectrum_wave[i]-wave, sig)
            fit_X = np.linspace(minwavelist[i], maxwavelist[i], 500)
            fit_Y = np.ones_like(fit_X)*fitcont
            for wave, amp, sig in model:
                fit_Y += amp*nGauss(fit_X-wave, sig)

            # Plot the image, spectrum, and fit
            numstring = repr(i+1)+"/"+repr(nimages)
            profile_plot = SkyRingPlot(images[i].inty, subarray,
                                       xcenlist[i], ycenlist[i], radiilist,
                                       spectrum_wave[i], spectrum_inty[i],
                                       sub_spec_inty, fit_X, fit_Y,
                                       skywavelibs[i], addedwaves[i],
                                       skywavelibs_extended[i], numstring,
                                       has_been_saved[i], fitsuccess)
            key = profile_plot.key

            # Shifting images and breakout condition
            if key == "a":
                i -= 1
            if key == "d":
                i += 1
            if i == -1 or i == nimages or key == "q":
                if np.all(has_been_saved):
                    break
                print("Error: Ring fits have not yet been " +
                      "saved for the following images:")
                imagenumbers = np.arange(nimages)+1
                print(imagenumbers[np.logical_not(has_been_saved)])
            i = min(max(i, 0), nimages-1)

            # Delete ring option
            if key == "s" and waves_to_fit:
                near_index = None
                if profile_plot.axis in [1, 3]:
                    clicked_radius = np.sqrt(
                        (profile_plot.xcoo - xcenlist[i])**2 +
                        (profile_plot.ycoo - ycenlist[i])**2)
                    near_index = np.argmin(np.abs(np.array(radiilist) -
                                                  clicked_radius))
                elif profile_plot.axis in [2, 4]:
                    near_index = np.argmin(np.abs(np.array(waves_to_fit) -
                                                  profile_plot.xcoo))
                if near_index is not None:
                    # Move a library line to the extended list; drop an
                    # added line entirely
                    nearest_wave = waves_to_fit[near_index]
                    if nearest_wave in skywavelibs[i]:
                        skywavelibs[i].remove(nearest_wave)
                        skywavelibs_extended[i].append(nearest_wave)
                    if nearest_wave in addedwaves[i]:
                        addedwaves[i].remove(nearest_wave)

            # Add ring option: nearest line from the extended list
            if key == "w":
                clicked_wave = _sky_click_to_wave(profile_plot, xcenlist[i],
                                                  ycenlist[i], wave0list[i],
                                                  Flist[i])
                if clicked_wave is not None and skywavelibs_extended[i]:
                    near = np.argmin(np.abs(np.array(skywavelibs_extended[i])
                                            - clicked_wave))
                    skywavelibs[i].append(skywavelibs_extended[i].pop(near))

            # Force add ring option
            if key == "e":
                clicked_wave = _sky_click_to_wave(profile_plot, xcenlist[i],
                                                  ycenlist[i], wave0list[i],
                                                  Flist[i])
                if clicked_wave is not None:
                    addedwaves[i].append(clicked_wave)

            # Save option
            if key == "r":
                final_fitted_waves[i] = waves_to_fit[:]
                final_fitted_intys[i] = fitintys[:]
                final_fitted_sigs[i] = fitsigs[:]
                has_been_saved[i] = True

            # Manual mode option
            if key == "x":
                result = _manual_ring_fit(
                    images[i].inty, wavearraylist[i], xcenlist[i],
                    ycenlist[i], wave0list[i], Flist[i], spectrum_wave[i],
                    spectrum_inty[i], minwavelist[i], maxwavelist[i],
                    skywavelibs[i], skywavelibs_extended[i],
                    has_been_saved[i], numstring,
                    waves_to_fit, fitintys, fitsigs, fitcont)
                if result is not None:
                    (final_fitted_waves[i], final_fitted_intys[i],
                     final_fitted_sigs[i]) = result
                    has_been_saved[i] = True

        # Close all of the images before rewriting them
        for image in images:
            image.close()

        # Subtract the ring profiles and update headers
        _write_ring_subtraction(fnlist, wavearraylist, final_fitted_waves,
                                final_fitted_intys, final_fitted_sigs)
    finally:
        # Restore the old keyword shortcuts
        plt.rcParams["keymap.save"] = oldsavekey
        plt.rcParams['keymap.fullscreen'] = oldfullscreenkey

    return
コード例 #4
0
def fit_wave_soln(fnlist):
    """Fits a wavelength solution to rings in a set of images. Appends this
    wavelength solution to the image headers as the keywords:
    'fpcala', 'fpcalb', ... 'fpcalf'
    
    Each object in fnlist must have a corresponding "median.fits" in its
    image directory, or this routine will not work.
    
    ARC ring images are fitted by adjusting the center, while the center is held
    fixed for night sky rings. A combination of fits to both sets of rings is
    used to determine a wavelength solution for the whole set of images.
    
    If the ARC rings disagree substantially with the night sky rings, it is
    recommended that users delete the ARC rings from the fit and use only the
    night sky rings.
    
    It is also known that the wavelength solution can sometimes be piecewise in
    time when a large jump in 'z' happens between two images; i.e. different
    wavelength solutions exist before and after the jump. The routine allows
    the user to make a piecewise solution for this reason, but this ability
    should be used sparingly.
    
    This routine contains one of the few hard-coded numbers in the pipeline,
    Fguess=5600. Currently F values are not written to the fits image headers,
    and this is a reasonable guess.
    
    Inputs:
    fnlist -> List of strings, each the path to a fits image. These images
    should all have been taken with the same order filter. If not, the routine
    will crash.
    
    """
    
    #This bit takes care of the 's' to save shortcut in matplotlib.
    oldsavekey = plt.rcParams["keymap.save"]
    plt.rcParams["keymap.save"] = ""
        
    #Open all of the images
    imagelist = []
    arclist = []
    objlist = []
    for i in range(len(fnlist)):
        imagelist.append(openfits(fnlist[i]))
        if i == 0: filt = imagelist[0][0].header["FILTER"]
        if imagelist[i][0].header["FILTER"] != filt:
            print "Error! Some of these images are in different filters!"
            crash()
        if imagelist[i][0].header["OBJECT"]=="ARC": arclist.append(imagelist[i])
        else:
            if not isfile(join(split(fnlist[i])[0],"median.fits")):
                print "Error! No 'median.fits' file found."
                crash()
            medimage = openfits(join(split(fnlist[i])[0],"median.fits"))
            imagelist[i][0].data += -medimage[0].data
            medimage.close()
            objlist.append(imagelist[i])
    
    #Load wavelength libraries
    arclib, nightlib = get_libraries(filt)
    if arclib is None:
        print "Error! Your filter isn't the wavelength library!"
        crash()

    



    #This next bit fits all of the rings that the user marks

    #Fit rings in the object images
    radlists = []
    for i in range(len(objlist)):
        radlists.append([])
    i=0
    while True:
        xgrid, ygrid = np.meshgrid(np.arange(objlist[i][0].data.shape[1]), np.arange(objlist[i][0].data.shape[0]))
        xcen = objlist[i][0].header["FPXCEN"]
        ycen = objlist[i][0].header["FPYCEN"]
        axcen = objlist[i][0].header["FPAXCEN"]
        aycen = objlist[i][0].header["FPAYCEN"]
        arad = objlist[i][0].header["FPARAD"]
        rgrid = np.sqrt((xgrid - xcen)**2 + (ygrid - ycen)**2)
        rbins = np.arange(arad-np.int(max(abs(axcen-xcen),abs(aycen-ycen))))+1
        intbins = np.empty_like(rbins)
        for j in range(len(rbins)):
            intbins[j] = np.median(objlist[i][0].data[np.logical_and(np.logical_and(objlist[i][0].data!=0,rgrid<rbins[j]),rgrid>rbins[j]-1)])
        ringplot = PlotRingProfile(objlist[i][0].data[aycen-arad:aycen+arad,axcen-arad:axcen+arad], #Data to be plotted. Only want stuff inside aperture
                                   rbins+(xcen-axcen)+arad, #Radii bins shifted to image center
                                   intbins*arad/np.percentile(np.abs(intbins),98)+(ycen-aycen)+arad, #Intensity bins, rescaled and shifted by image center
                                   xcen-axcen+arad, ycen-aycen+arad, #Shifted center
                                   radlists[i], #Previously fitted rings
                                   repr(i+1)+"/"+repr(len(objlist))) #numstring
        #Changing images and loop breakout conditions
        if ringplot.key == "d": i+=1
        if ringplot.key == "a": i+=-1
        if i == -1 or i == len(objlist):
            while True:
                yn = raw_input("Finished marking sky rings? (y/n) ")
                if "n" in yn or "N" in yn:
                    if i == -1: i=0
                    if i == len(objlist): i = len(objlist)-1
                    break
                elif "y" in yn or "Y" in yn:
                    break
        if i == -1 or i == len(objlist): break
        #Force-marking a ring
        if ringplot.key == "e" and ringplot.xcoo != None: radlists[i].append(ringplot.xcoo-arad-(xcen-axcen))
        #Deleting a ring
        if ringplot.key == "s" and ringplot.xcoo != None and len(radlists[i])>0:
            radlists[i].pop(np.argmin(np.array(radlists[i])-np.sqrt((ringplot.xcoo-arad-(xcen-axcen))**2 + (ringplot.ycoo-arad-(ycen-aycen))**2)))
        #Fitting a ring profile
        if ringplot.key == "w" and ringplot.xcoo != None:
            x = rbins[max(ringplot.xcoo-arad-(xcen-axcen)-50,0):min(ringplot.xcoo-arad-(xcen-axcen)+50,len(rbins))]**2
            y = intbins[max(ringplot.xcoo-arad-(xcen-axcen)-50,0):min(ringplot.xcoo-arad-(xcen-axcen)+50,len(rbins))]
            fit = GaussFit(x,y)
            fitplot = PlotRingFit(x,y,fit)
            if fitplot.key == "w": radlists[i].append(np.sqrt(fit[2]))
    zo = []
    to = []
    ro = []
    for i in range(len(objlist)):
        for j in range(len(radlists[i])):
            zo.append(objlist[i][0].header["ET1Z"])
            to.append(objlist[i][0].header["JD"])
            ro.append(radlists[i][j])
            
    #Fit rings in the ARC images
    xcen = objlist[0][0].header["FPXCEN"]
    ycen = objlist[0][0].header["FPYCEN"]
    radlists = []
    for i in range(len(arclist)):
        radlists.append([])
    i=0
    while True:
        xgrid, ygrid = np.meshgrid(np.arange(arclist[i][0].data.shape[1]), np.arange(arclist[i][0].data.shape[0]))
        axcen = arclist[i][0].header["FPAXCEN"]
        aycen = arclist[i][0].header["FPAYCEN"]
        arad = arclist[i][0].header["FPARAD"]
        rgrid = np.sqrt((xgrid - xcen)**2 + (ygrid - ycen)**2)
        rbins = np.arange(arad-np.int(max(abs(axcen-xcen),abs(aycen-ycen))))+1
        intbins = np.empty_like(rbins)
        for j in range(len(rbins)):
            intbins[j] = np.median(arclist[i][0].data[np.logical_and(np.logical_and(arclist[i][0].data!=0,rgrid<rbins[j]),rgrid>rbins[j]-1)])
        ringplot = PlotRingProfile(arclist[i][0].data[aycen-arad:aycen+arad,axcen-arad:axcen+arad], #Data to be plotted. Only want stuff inside aperture
                                   rbins+(xcen-axcen)+arad, #Radii bins shifted to image center
                                   intbins*arad/np.percentile(np.abs(intbins),98)+(ycen-aycen)+arad, #Intensity bins, rescaled and shifted by image center
                                   xcen-axcen+arad, ycen-aycen+arad, #Shifted center
                                   radlists[i], #Previously fitted rings
                                   repr(i+1)+"/"+repr(len(arclist))) #numstring
        #Changing images and loop breakout conditions
        if ringplot.key == "d": i+=1
        if ringplot.key == "a": i+=-1
        if i == -1 or i == len(arclist):
            while True:
                yn = raw_input("Finished marking ARC rings? (y/n) ")
                if "n" in yn or "N" in yn:
                    if i == -1: i=0
                    if i == len(arclist): i = len(arclist)-1
                    break
                elif "y" in yn or "Y" in yn:
                    break
        if i == -1 or i == len(arclist): break
        #Force-marking a ring
        if ringplot.key == "e" and ringplot.xcoo != None: radlists[i].append(ringplot.xcoo-arad-(xcen-axcen))
        #Deleting a ring
        if ringplot.key == "s" and ringplot.xcoo != None and len(radlists[i])>0:
            radlists[i].pop(np.argmin(np.array(radlists[i])-np.sqrt((ringplot.xcoo-arad-(xcen-axcen))**2 + (ringplot.ycoo-arad-(ycen-aycen))**2)))
        #Fitting a ring profile
        if ringplot.key == "w" and ringplot.xcoo != None:
            x = rbins[max(ringplot.xcoo-arad-(xcen-axcen)-50,0):min(ringplot.xcoo-arad-(xcen-axcen)+50,len(rbins))]**2
            y = intbins[max(ringplot.xcoo-arad-(xcen-axcen)-50,0):min(ringplot.xcoo-arad-(xcen-axcen)+50,len(rbins))]
            fit = GaussFit(x,y)
            fitplot = PlotRingFit(x,y,fit)
            if fitplot.key == "w": radlists[i].append(np.sqrt(fit[2]))
    za = []
    ta = []
    ra = []
    for i in range(len(arclist)):
        for j in range(len(radlists[i])):
            za.append(arclist[i][0].header["ET1Z"])
            ta.append(arclist[i][0].header["JD"])
            ra.append(radlists[i][j])
    
#     #Load previous ring fits from a text file - COMMENT THIS OUT LATER
#     rr,zz,tt = np.loadtxt("test.out",unpack=True)
#     za = list(zz[zz>0])
#     ta = list(tt[zz>0])
#     ra = list(rr[zz>0])
#     zo = list(zz[zz<0])
#     to = list(tt[zz<0])
#     ro = list(rr[zz<0])

    #Now we try to get a good guess at the wavelengths
    
    #Get a good guess at which wavelengths are which
    Bguess = objlist[0][0].header["ET1B"]
    Fguess = 5600
    
    #Figure out A by matching rings to the wavelength libraries
    master_r = np.array(ro+ra)
    master_z = np.array(zo+za)
    wavematch = np.zeros_like(master_r)
    isnight = np.array([True]*len(ro)+[False]*len(ra))
    oldrms = 10000 #Really high initial RMS for comparisons
    for i in range(len(master_r)):
        if isnight[i]: lib = nightlib
        else: lib = arclib
        for j in range(len(lib)):
            #Assume the i'th ring is the j'th line
            Aguess = lib[j]*np.sqrt(1+master_r[i]**2/Fguess**2)-Bguess*master_z[i]
            #What are all of the other rings, given this A?
            waveguess = (Aguess+Bguess*master_z)/np.sqrt(1+master_r**2/Fguess**2)
            for k in range(len(master_r)):
                if isnight[k]: wavematch[k] = nightlib[np.argmin(np.abs(nightlib-waveguess[k]))]
                else: wavematch[k] = arclib[np.argmin(np.abs(arclib-waveguess[k]))]
            rms = np.sqrt(np.average((waveguess-wavematch)**2))
            if rms < oldrms:
                #This is the new best solution. Keep it!
                oldrms = rms
                bestA = Aguess
                master_wave = wavematch.copy()
    
    #Make more master arrays for the plotting
    master_t = np.array(to+ta)
    t0 = np.min(master_t)
    master_t += -t0
    master_t *= 24*60 #Convert to minutes
    master_color = np.array(len(ro)*["blue"]+len(ra)*["red"]) #Colors for plotting
    toggle = np.ones(len(master_r),dtype="bool")
    dotime = False
    time_dividers = []
    
    #Do the interactive plotting
    while True:
        rplot = master_r[toggle]
        zplot = master_z[toggle]
        tplot = master_t[toggle]
        colorplot = master_color[toggle]
        waveplot = master_wave[toggle]
        fitplot = np.zeros(len(waveplot))
        xs = np.zeros((3,len(rplot)))
        xs[0] = rplot
        xs[1] = zplot
        xs[2] = tplot
        fit = [0]*(len(time_dividers)+1)
        time_dividers = sorted(time_dividers)
        if len(time_dividers)>1: print "Warning: Too many time divisions is likely unphysical. Be careful!"
        for i in range(len(time_dividers)+1):
            #Create a slice for all of the wavelengths before this time divider
            #but after the one before it
            if len(time_dividers)==0: tslice = tplot==tplot
            elif i == 0: tslice = tplot<time_dividers[i]
            elif i==len(time_dividers): tslice = tplot>time_dividers[i-1]
            else: tslice = np.logical_and(tplot<time_dividers[i],tplot>time_dividers[i-1])
            if dotime:
                fit[i] = curve_fit(fpfunc_for_curve_fit_with_t, xs[:,tslice], waveplot[tslice], p0=(bestA,Bguess,0,Fguess))[0]
                fitplot[tslice] = fpfunc_for_curve_fit_with_t(xs[:,tslice], fit[i][0], fit[i][1], fit[i][2], fit[i][3])
            else:
                fit[i] = curve_fit(fpfunc_for_curve_fit, xs[:,tslice], waveplot[tslice], p0=(bestA,Bguess,Fguess))[0]
                fitplot[tslice] = fpfunc_for_curve_fit(xs[:,tslice], fit[i][0], fit[i][1], fit[i][2])
        resid = waveplot - fitplot
        solnplot = WaveSolnPlot(rplot,zplot,tplot,waveplot,resid,colorplot,time_dividers)
        #Breakout case
        if solnplot.key == "a":
            while True:
                for i in range(len(time_dividers)+1):
                    if dotime: print "Solution 1: A = "+str(fit[i][0])+", B = "+str(fit[i][1])+", E = "+str(fit[i][2])+", F = "+str(fit[i][3])
                    else: print "Solution 1: A = "+str(fit[i][0])+", B = "+str(fit[i][1])+", F = "+str(fit[i][2])
                print "Residual rms="+str(np.sqrt(np.average(resid**2)))+" for "+repr(len(time_dividers)+1)+" independent "+repr(3+dotime)+"-parameter fits to "+repr(len(rplot))+" rings."
                yn = raw_input("Accept wavelength solution? (y/n) ")
                if "n" in yn or "N" in yn:
                    break
                elif "y" in yn or "Y" in yn:
                    solnplot.key = "QUIT"
                    break
        if solnplot.key == "QUIT": break
        #Restore all points case
        if solnplot.key == "r": toggle = np.ones(len(master_r),dtype="bool")
        #Delete nearest point case
        if solnplot.key == "d" and solnplot.axis != None:
            #Figure out which plot was clicked in
            if solnplot.axis == 1:
                #Resid vs. z plot
                z_loc = solnplot.xcoo
                resid_loc = solnplot.ycoo
                dist2 = ((zplot-z_loc)/(np.max(zplot)-np.min(zplot)))**2 + ((resid-resid_loc)/(np.max(resid)-np.min(resid)))**2
            elif solnplot.axis == 2:
                #Resid vs. R plot
                r_loc = solnplot.xcoo
                resid_loc = solnplot.ycoo
                dist2 = ((rplot-r_loc)/(np.max(rplot)-np.min(rplot)))**2 + ((resid-resid_loc)/(np.max(resid)-np.min(resid)))**2
            elif solnplot.axis == 3:
                #Resid vs. T plot
                t_loc = solnplot.xcoo
                resid_loc = solnplot.ycoo
                dist2 = ((tplot-t_loc)/(np.max(tplot)-np.min(tplot)))**2 + ((resid-resid_loc)/(np.max(resid)-np.min(resid)))**2
            elif solnplot.axis == 4:
                #Resid vs. Wave plot
                wave_loc = solnplot.xcoo
                resid_loc = solnplot.ycoo
                dist2 = ((waveplot-wave_loc)/(np.max(waveplot)-np.min(waveplot)))**2 + ((resid-resid_loc)/(np.max(resid)-np.min(resid)))**2
            #Get the radius and time of the worst ring
            r_mask = rplot[dist2 == np.min(dist2)][0]
            t_mask = tplot[dist2 == np.min(dist2)][0]
            toggle[np.logical_and(master_r == r_mask, master_t == t_mask)] = False
        #Fit for time case
        if solnplot.key == "t": dotime = not dotime
        #Add time break
        if solnplot.key == "w":
            timeplot = TimePlot(tplot,resid,colorplot,time_dividers)
            if timeplot.xcoo != None: time_dividers.append(timeplot.xcoo)
        #Remove time breaks
        if solnplot.key == "q":
            time_dividers = []