Ejemplo n.º 1
0
def main(args):
    """Find the best stellar model for every standard star and normalize it.

    Frames are grouped by (exposure, spectrograph), flat-fielded and
    sky-subtracted, the standard stars with enough blue S/N are selected,
    and for each of them a linear combination of template models is fitted.
    The fitted model is redshifted, dust-extincted, scaled to the reported
    R magnitude and finally written with ``io.write_stdstar_models`` so it
    can be used for the flux calibration.

    Args:
        args: parsed command-line options. The attributes read here are
            ``frames``, ``skymodels``, ``fiberflats``, ``starmodels``,
            ``outfile``, ``color`` ('G-R' or 'R-Z'), ``delta_color``,
            ``ncpu``, ``z_max``, ``z_res`` and ``template_error``.

    Raises:
        ValueError: if the frames come from different spectrographs, have
            incompatible fibermaps, contain no standard star, or if
            ``args.color`` is not one of the supported colors.
        SystemExit: with code 12 if no star passes the blue S/N cut or if
            no star could be fitted.
    """

    log = get_logger()

    log.info("mag delta %s = %f (for the pre-selection of stellar models)" %
             (args.color, args.delta_color))
    log.info('multiprocess parallelizing with {} processes'.format(args.ncpu))

    # READ DATA
    ############################################
    # First loop through and group by exposure and spectrograph
    frames_by_expid = {}
    for filename in args.frames:
        log.info("reading %s" % filename)
        frame = io.read_frame(filename)
        expid = safe_read_key(frame.meta, "EXPID")
        camera = safe_read_key(frame.meta, "CAMERA").strip().lower()
        # second character of the camera name is used as spectrograph id
        spec = camera[1]
        uniq_key = (expid, spec)
        frames_by_expid.setdefault(uniq_key, {})[camera] = frame

    frames = {}
    flats = {}
    skies = {}

    spectrograph = None
    starfibers = None
    starindices = None
    fibermap = None

    # For each unique expid,spec pair, get the logical OR of the FIBERSTATUS for all
    # cameras and then proceed with extracting the frame information
    # once we modify the fibermap FIBERSTATUS
    for (expid, spec), camdict in frames_by_expid.items():

        fiberstatus = None
        for frame in camdict.values():
            if fiberstatus is None:
                fiberstatus = frame.fibermap['FIBERSTATUS'].data.copy()
            else:
                fiberstatus |= frame.fibermap['FIBERSTATUS']

        for camera, frame in camdict.items():
            frame.fibermap['FIBERSTATUS'] |= fiberstatus
            # Set fibermask flagged spectra to have 0 flux and variance
            frame = get_fiberbitmasked_frame(frame,
                                             bitmask='stdstars',
                                             ivar_framemask=True)
            frame_fibermap = frame.fibermap
            frame_starindices = np.where(isStdStar(frame_fibermap))[0]

            #- Confirm that all fluxes have entries but trust targeting bits
            #- to get basic magnitude range correct
            keep = np.ones(len(frame_starindices), dtype=bool)

            # With mag = 22.5 - 2.5*log10(flux) (as used further down), the
            # two bounds below keep fluxes between magnitude 30 and 0.
            for colname in ['FLUX_G', 'FLUX_R', 'FLUX_Z']:  #- and W1 and W2?
                keep &= frame_fibermap[colname][frame_starindices] > 10**(
                    (22.5 - 30) / 2.5)
                keep &= frame_fibermap[colname][frame_starindices] < 10**(
                    (22.5 - 0) / 2.5)

            frame_starindices = frame_starindices[keep]

            if spectrograph is None:
                # the first frame defines the reference star selection
                spectrograph = frame.spectrograph
                fibermap = frame_fibermap
                starindices = frame_starindices
                starfibers = fibermap["FIBER"][starindices]

            elif spectrograph != frame.spectrograph:
                log.error("incompatible spectrographs %d != %d" %
                          (spectrograph, frame.spectrograph))
                raise ValueError("incompatible spectrographs %d != %d" %
                                 (spectrograph, frame.spectrograph))
            elif starindices.size != frame_starindices.size or np.sum(
                    starindices != frame_starindices) > 0:
                log.error("incompatible fibermap")
                raise ValueError("incompatible fibermap")

            frames.setdefault(camera, []).append(frame)

    # possibly cleanup memory
    del frames_by_expid

    for filename in args.skymodels:
        log.info("reading %s" % filename)
        sky = io.read_sky(filename)
        camera = safe_read_key(sky.header, "CAMERA").strip().lower()
        skies.setdefault(camera, []).append(sky)

    for filename in args.fiberflats:
        log.info("reading %s" % filename)
        flat = io.read_fiberflat(filename)
        camera = safe_read_key(flat.header, "CAMERA").strip().lower()

        # NEED TO ADD MORE CHECKS
        if camera in flats:
            log.warning(
                "cannot handle several flats of same camera (%s), will use only the first one"
                % camera)
        else:
            flats[camera] = flat

    # starindices is still None when no frame was read at all; report that
    # the same way as an empty selection instead of failing on ``.size``.
    if starindices is None or starindices.size == 0:
        log.error("no STD star found in fibermap")
        raise ValueError("no STD star found in fibermap")

    log.info("found %d STD stars" % starindices.size)

    log.warning("Not using flux errors for Standard Star fits!")

    # DIVIDE FLAT AND SUBTRACT SKY , TRIM DATA
    ############################################
    # since poping dict, we need to copy keys to iterate over to avoid
    # RuntimeError due to changing dict
    frame_cams = list(frames.keys())
    for cam in frame_cams:

        if cam not in skies:
            log.warning("Missing sky for %s" % cam)
            frames.pop(cam)
            continue
        if cam not in flats:
            log.warning("Missing flat for %s" % cam)
            frames.pop(cam)
            continue

        flat = flats[cam]
        for frame, sky in zip(frames[cam], skies[cam]):
            frame.flux = frame.flux[starindices]
            frame.ivar = frame.ivar[starindices]
            frame.ivar *= (frame.mask[starindices] == 0)
            frame.ivar *= (sky.ivar[starindices] != 0)
            frame.ivar *= (sky.mask[starindices] == 0)
            frame.ivar *= (flat.ivar[starindices] != 0)
            frame.ivar *= (flat.mask[starindices] == 0)
            frame.flux *= (frame.ivar > 0)  # just for clean plots
            for star in range(frame.flux.shape[0]):
                # frame.flux/ivar were trimmed to the standard stars above,
                # but flat and sky still hold every fiber of the frame (they
                # are indexed with starindices just above), so they have to
                # be addressed with the star's original row, not its rank.
                fiber_row = starindices[star]
                ok = np.where((frame.ivar[star] > 0)
                              & (flat.fiberflat[fiber_row] != 0))[0]
                if ok.size > 0:
                    frame.flux[star] = frame.flux[star] / flat.fiberflat[
                        fiber_row] - sky.flux[fiber_row]
            frame.resolution_data = frame.resolution_data[starindices]

    # CHECK S/N
    ############################################
    # for each band in 'brz', record quadratic sum of median S/N across wavelength
    snr = dict()
    for band in ['b', 'r', 'z']:
        snr[band] = np.zeros(starindices.size)
    for cam in frames:
        band = cam[0].lower()
        for frame in frames[cam]:
            msnr = np.median(frame.flux * np.sqrt(frame.ivar) /
                             np.sqrt(np.gradient(frame.wave)),
                             axis=1)  # median SNR per sqrt(A.)
            msnr *= (msnr > 0)
            snr[band] = np.sqrt(snr[band]**2 + msnr**2)
    log.info("SNR(B) = {}".format(snr['b']))

    ###############################
    max_number_of_stars = 50
    min_blue_snr = 4.
    ###############################
    # keep at most max_number_of_stars stars, highest blue S/N first
    indices = np.argsort(snr['b'])[::-1][:max_number_of_stars]

    validstars = np.where(snr['b'][indices] > min_blue_snr)[0]

    #- TODO: later we filter on models based upon color, thus throwing
    #- away very blue stars for which we don't have good models.

    log.info("Number of stars with median stacked blue S/N > {} /sqrt(A) = {}".
             format(min_blue_snr, validstars.size))
    if validstars.size == 0:
        log.error("No valid star")
        sys.exit(12)

    # convert positions in the sorted list back to positions in the star list
    validstars = indices[validstars]

    for band in ['b', 'r', 'z']:
        snr[band] = snr[band][validstars]

    log.info("BLUE SNR of selected stars={}".format(snr['b']))

    for cam in frames:
        for frame in frames[cam]:
            frame.flux = frame.flux[validstars]
            frame.ivar = frame.ivar[validstars]
            frame.resolution_data = frame.resolution_data[validstars]
    starindices = starindices[validstars]
    starfibers = starfibers[validstars]
    nstars = starindices.size
    fibermap = Table(fibermap[starindices])

    # MASK OUT THROUGHPUT DIP REGION
    ############################################
    mask_throughput_dip_region = True
    if mask_throughput_dip_region:
        wmin = 4300.
        wmax = 4500.
        log.warning(
            "Masking out the wavelength region [{},{}]A in the standard star fit"
            .format(wmin, wmax))
        # the masking loop lives inside the switch so that wmin/wmax are
        # never used when the switch is turned off
        for cam in frames:
            for frame in frames[cam]:
                ii = np.where((frame.wave >= wmin) & (frame.wave <= wmax))[0]
                if ii.size > 0:
                    frame.ivar[:, ii] = 0

    # READ MODELS
    ############################################
    log.info("reading star models in %s" % args.starmodels)
    stdwave, stdflux, templateid, teff, logg, feh = io.read_stdstar_templates(
        args.starmodels)

    # COMPUTE MAGS OF MODELS FOR EACH STD STAR MAG
    ############################################

    #- Support older fibermaps
    if 'PHOTSYS' not in fibermap.colnames:
        log.warning('Old fibermap format; using defaults for missing columns')
        log.warning("    PHOTSYS = 'S'")
        log.warning("    MW_TRANSMISSION_G/R/Z = 1.0")
        log.warning("    EBV = 0.0")
        fibermap['PHOTSYS'] = 'S'
        fibermap['MW_TRANSMISSION_G'] = 1.0
        fibermap['MW_TRANSMISSION_R'] = 1.0
        fibermap['MW_TRANSMISSION_Z'] = 1.0
        fibermap['EBV'] = 0.0

    # one filter response per (band, photometric system) present in the data
    model_filters = dict()
    for band in ["G", "R", "Z"]:
        for photsys in np.unique(fibermap['PHOTSYS']):
            model_filters[band + photsys] = load_legacy_survey_filter(
                band=band, photsys=photsys)

    log.info("computing model mags for %s" % sorted(model_filters.keys()))
    model_mags = dict()
    fluxunits = 1e-17 * units.erg / units.s / units.cm**2 / units.Angstrom
    for filter_name, filter_response in model_filters.items():
        model_mags[filter_name] = filter_response.get_ab_magnitude(
            stdflux * fluxunits, stdwave)
    log.info("done computing model mags")

    # LOOP ON STARS TO FIND BEST MODEL
    ############################################
    linear_coefficients = np.zeros((nstars, stdflux.shape[0]))
    chi2dof = np.zeros((nstars))
    redshift = np.zeros((nstars))
    normflux = []

    star_mags = dict()
    star_unextincted_mags = dict()
    for band in ['G', 'R', 'Z']:
        star_mags[band] = 22.5 - 2.5 * np.log10(fibermap['FLUX_' + band])
        star_unextincted_mags[band] = 22.5 - 2.5 * np.log10(
            fibermap['FLUX_' + band] / fibermap['MW_TRANSMISSION_' + band])

    star_colors = dict()
    star_colors['G-R'] = star_mags['G'] - star_mags['R']
    star_colors['R-Z'] = star_mags['R'] - star_mags['Z']

    star_unextincted_colors = dict()
    star_unextincted_colors[
        'G-R'] = star_unextincted_mags['G'] - star_unextincted_mags['R']
    star_unextincted_colors[
        'R-Z'] = star_unextincted_mags['R'] - star_unextincted_mags['Z']

    fitted_model_colors = np.zeros(nstars)

    for star in range(nstars):

        log.info("finding best model for observed star #%d" % star)

        # np.array of wave,flux,ivar,resol
        wave = {}
        flux = {}
        ivar = {}
        resolution_data = {}
        for camera in frames:
            for i, frame in enumerate(frames[camera]):
                identifier = "%s-%d" % (camera, i)
                wave[identifier] = frame.wave
                flux[identifier] = frame.flux[star]
                ivar[identifier] = frame.ivar[star]
                resolution_data[identifier] = frame.resolution_data[star]

        # preselect models based on magnitudes
        photsys = fibermap['PHOTSYS'][star]
        if args.color not in ['G-R', 'R-Z']:
            raise ValueError('Unknown color {}'.format(args.color))
        bands = args.color.split("-")
        model_colors = model_mags[bands[0] + photsys] - model_mags[bands[1] +
                                                                   photsys]

        color_diff = model_colors - star_unextincted_colors[args.color][star]
        selection = np.abs(color_diff) < args.delta_color
        if np.sum(selection) == 0:
            # the star keeps chi2dof == 0 and is dropped from the output
            log.warning("no model in the selected color range for this star")
            continue

        # smallest cube in parameter space including this selection (needed for interpolation)
        new_selection = (teff >= np.min(teff[selection])) & (teff <= np.max(
            teff[selection]))
        new_selection &= (logg >= np.min(logg[selection])) & (logg <= np.max(
            logg[selection]))
        new_selection &= (feh >= np.min(feh[selection])) & (feh <= np.max(
            feh[selection]))
        selection = np.where(new_selection)[0]

        log.info(
            "star#%d fiber #%d, %s = %f, number of pre-selected models = %d/%d"
            % (star, starfibers[star], args.color,
               star_unextincted_colors[args.color][star], selection.size,
               stdflux.shape[0]))

        # Match unextincted standard stars to data
        coefficients, redshift[star], chi2dof[star] = match_templates(
            wave,
            flux,
            ivar,
            resolution_data,
            stdwave,
            stdflux[selection],
            teff[selection],
            logg[selection],
            feh[selection],
            ncpu=args.ncpu,
            z_max=args.z_max,
            z_res=args.z_res,
            template_error=args.template_error)

        linear_coefficients[star, selection] = coefficients

        log.info(
            'Star Fiber: {}; TEFF: {:.3f}; LOGG: {:.3f}; FEH: {:.3f}; Redshift: {:g}; Chisq/dof: {:.3f}'
            .format(starfibers[star], np.inner(teff,
                                               linear_coefficients[star]),
                    np.inner(logg, linear_coefficients[star]),
                    np.inner(feh, linear_coefficients[star]), redshift[star],
                    chi2dof[star]))

        # Apply redshift to original spectrum at full resolution
        model = np.zeros(stdwave.size)
        redshifted_stdwave = stdwave * (1 + redshift[star])
        for i, c in enumerate(linear_coefficients[star]):
            if c != 0:
                model += c * np.interp(stdwave, redshifted_stdwave, stdflux[i])

        # Apply dust extinction to the model
        log.info("Applying MW dust extinction to star {} with EBV = {}".format(
            star, fibermap['EBV'][star]))
        model *= dust_transmission(stdwave, fibermap['EBV'][star])

        # Compute final color of dust-extincted model
        # (photsys and bands were validated and computed above for this star)
        model_mag1 = model_filters[bands[0] + photsys].get_ab_magnitude(
            model * fluxunits, stdwave)
        model_mag2 = model_filters[bands[1] + photsys].get_ab_magnitude(
            model * fluxunits, stdwave)
        fitted_model_colors[star] = model_mag1 - model_mag2
        # both supported colors contain R, so one of the two is the R mag
        if bands[0] == "R":
            model_magr = model_mag1
        elif bands[1] == "R":
            model_magr = model_mag2

        #- TODO: move this back into normalize_templates, at the cost of
        #- recalculating a model magnitude?

        # Normalize the best model using reported magnitude
        scalefac = 10**((model_magr - star_mags['R'][star]) / 2.5)

        log.info('scaling R mag {:.3f} to {:.3f} using scale {}'.format(
            model_magr, star_mags['R'][star], scalefac))
        normflux.append(model * scalefac)

    # Now write the normalized flux for all best models to a file
    normflux = np.array(normflux)

    # stars skipped in the loop above never got a chi2dof and have no
    # entry in normflux
    fitted_stars = np.where(chi2dof != 0)[0]
    if fitted_stars.size == 0:
        log.error("No star has been fit.")
        sys.exit(12)

    data = {}
    data['LOGG'] = linear_coefficients[fitted_stars, :].dot(logg)
    data['TEFF'] = linear_coefficients[fitted_stars, :].dot(teff)
    data['FEH'] = linear_coefficients[fitted_stars, :].dot(feh)
    data['CHI2DOF'] = chi2dof[fitted_stars]
    data['REDSHIFT'] = redshift[fitted_stars]
    data['COEFF'] = linear_coefficients[fitted_stars, :]
    data['DATA_%s' % args.color] = star_colors[args.color][fitted_stars]
    data['MODEL_%s' % args.color] = fitted_model_colors[fitted_stars]
    data['BLUE_SNR'] = snr['b'][fitted_stars]
    data['RED_SNR'] = snr['r'][fitted_stars]
    data['NIR_SNR'] = snr['z'][fitted_stars]
    io.write_stdstar_models(args.outfile, normflux, stdwave,
                            starfibers[fitted_stars], data)
Ejemplo n.º 2
0
def main(args):
    """Compute the flux calibration of one frame and write it to disk.

    The frame is fiber-masked, flat-fielded and sky-subtracted, the
    standard-star models are filtered with the requested quality cuts, and
    the calibration is computed from the remaining stars. Optional QA
    output is produced when ``args.qafile`` is set.

    Args:
        args: parsed command-line options. The attributes read here are
            ``infile``, ``fiberflat``, ``sky``, ``models``, ``outfile``,
            ``chi2cut``, ``delta_color_cut``, ``min_color``,
            ``chi2cut_nsig``, ``highest_throughput``, ``qafile`` and
            ``qafig``.

    Returns:
        None. Returns early, without writing anything, when every selected
        standard-star spectrum is fully masked.

    Raises:
        SystemExit: with code 12 if the cuts discard all stars or if the
            model fibers are not standard stars in the fibermap.
    """

    log = get_logger()

    # log an equivalent command line for traceability
    cmd = [
        'desi_compute_fluxcalibration',
    ]
    for key, value in args.__dict__.items():
        if value is not None:
            cmd += ['--' + key, str(value)]
    cmd = ' '.join(cmd)
    log.info(cmd)

    log.info("read frame")
    # read frame
    frame = read_frame(args.infile)

    # Set fibermask flagged spectra to have 0 flux and variance
    frame = get_fiberbitmasked_frame(frame,
                                     bitmask='flux',
                                     ivar_framemask=True)

    log.info("apply fiberflat")
    # read fiberflat
    fiberflat = read_fiberflat(args.fiberflat)

    # apply fiberflat
    apply_fiberflat(frame, fiberflat)

    log.info("subtract sky")
    # read sky
    skymodel = read_sky(args.sky)

    # subtract sky
    subtract_sky(frame, skymodel)

    log.info("compute flux calibration")

    # read models
    model_flux, model_wave, model_fibers, model_metadata = read_stdstar_models(
        args.models)

    # boolean selection of the models that survive all requested cuts
    ok = np.ones(len(model_metadata), dtype=bool)

    if args.chi2cut > 0:
        log.info("Apply cut CHI2DOF<{}".format(args.chi2cut))
        ok &= (model_metadata["CHI2DOF"] < args.chi2cut)
    if args.delta_color_cut > 0:
        log.info("Apply cut |delta color|<{}".format(args.delta_color_cut))
        ok &= (np.abs(model_metadata["MODEL_G-R"] - model_metadata["DATA_G-R"])
               < args.delta_color_cut)
    if args.min_color is not None:
        log.info("Apply cut DATA_G-R>{}".format(args.min_color))
        ok &= (model_metadata["DATA_G-R"] > args.min_color)
    if args.chi2cut_nsig > 0:
        # automatically reject stars that are chi2 outliers
        mchi2 = np.median(model_metadata["CHI2DOF"])
        rmschi2 = np.std(model_metadata["CHI2DOF"])
        maxchi2 = mchi2 + args.chi2cut_nsig * rmschi2
        log.info("Apply cut CHI2DOF<{} based on chi2cut_nsig={}".format(
            maxchi2, args.chi2cut_nsig))
        ok &= (model_metadata["CHI2DOF"] <= maxchi2)

    ok = np.where(ok)[0]
    if ok.size == 0:
        log.error("cuts discarded all stars")
        sys.exit(12)
    nstars = model_flux.shape[0]
    nbad = nstars - ok.size
    if nbad > 0:
        log.warning("discarding %d star(s) out of %d because of cuts" %
                    (nbad, nstars))
        model_flux = model_flux[ok]
        model_fibers = model_fibers[ok]
        model_metadata = model_metadata[:][ok]

    # check that the model_fibers are actually standard stars
    fibermap = frame.fibermap

    # Row of each model fiber inside this frame.
    # NOTE(review): presumably 500 is the number of fibers per frame -
    # confirm against the instrument definition.
    star_rows = model_fibers % 500

    ## check whether star fibers from args.models are consistent with fibers from fibermap
    ## if not print the OBJTYPE from fibermap for the fibers numbers in args.models and exit
    fibermap_std_indices = np.where(isStdStar(fibermap))[0]
    # np.isin replaces np.in1d, which is deprecated in recent NumPy
    if np.any(~np.isin(star_rows, fibermap_std_indices)):
        target_colnames, target_masks, survey = main_cmx_or_sv(fibermap)
        colname = target_colnames[0]
        for i in star_rows:
            log.error(
                "inconsistency with spectrum {}, OBJTYPE={}, {}={} in fibermap"
                .format(i, fibermap["OBJTYPE"][i], colname,
                        fibermap[colname][i]))
        sys.exit(12)

    # Make sure the fibers of interest aren't entirely masked.
    fully_masked = np.sum(frame.ivar[star_rows, :] == 0,
                          axis=1) == frame.nwave
    if np.sum(fully_masked) == len(model_fibers):
        log.warning('All standard-star spectra are masked!')
        return

    fluxcalib = compute_flux_calibration(
        frame,
        model_wave,
        model_flux,
        star_rows,
        highest_throughput_nstars=args.highest_throughput)

    # QA
    if args.qafile is not None:
        log.info("performing fluxcalib QA")
        # Load
        qaframe = load_qa_frame(args.qafile,
                                frame_meta=frame.meta,
                                flavor=frame.meta['FLAVOR'])
        # Run
        qaframe.run_qa('FLUXCALIB', (frame, fluxcalib))
        # Write
        write_qa_frame(args.qafile, qaframe)
        log.info("successfully wrote {:s}".format(args.qafile))
        # Figure(s)
        if args.qafig is not None:
            qa_plots.frame_fluxcalib(args.qafig, qaframe, frame, fluxcalib)

    # write result
    write_flux_calibration(args.outfile, fluxcalib, header=frame.meta)

    log.info("successfully wrote %s" % args.outfile)
Ejemplo n.º 3
0
def compute_fiberflat(frame, nsig_clipping=10., accuracy=5.e-4, minval=0.1, maxval=10.,max_iterations=15,smoothing_res=5.,max_bad=100,max_rej_it=5,min_sn=0,diag_epsilon=1e-3) :
    """Compute fiber flat by deriving an average spectrum and dividing all fiber data by this average.
    Input data are expected to be on the same wavelength grid, with uncorrelated noise.
    They however do not have exactly the same resolution.

    Args:
        frame (desispec.Frame): input Frame object with attributes
            wave, flux, ivar, resolution_data
        nsig_clipping : [optional] sigma clipping value for outlier rejection
        accuracy : [optional] accuracy of fiberflat (end test for the iterative loop)
        minval: [optional] mask pixels with flux < minval * median fiberflat.
        maxval: [optional] mask pixels with flux > maxval * median fiberflat.
        max_iterations: [optional] maximum number of iterations
        smoothing_res: [optional] spacing between spline fit nodes for smoothing the fiberflat
        max_bad: [optional] mask entire fiber if more than max_bad-1 initially unmasked pixels are masked during the iterations
        max_rej_it: [optional] reject at most the max_rej_it worst pixels in each iteration
        min_sn: [optional] mask portions with signal to noise less than min_sn
        diag_epsilon: [optional] size of the regularization term in the deconvolution


    Returns:
        desispec.FiberFlat object with attributes
            wave, fiberflat, ivar, mask, meanspec

    Notes:
    - we first iteratively :

       - compute a deconvolved mean spectrum
       - compute a fiber flat using the resolution convolved mean spectrum for each fiber
       - smooth the fiber flat along wavelength
       - clip outliers

    - then we compute a fiberflat at the native fiber resolution (not smoothed)

    - the routine returns the fiberflat, its inverse variance , mask, and the deconvolved mean spectrum

    - the fiberflat is the ratio data/mean , so this flat should be divided to the data

    NOTE THAT THIS CODE HAS NOT BEEN TESTED WITH ACTUAL FIBER TRANSMISSION VARIATIONS,
    OUTLIER PIXELS, DEAD COLUMNS ...
    """
    log=get_logger()
    log.info("starting")

    #
    # chi2 = sum_(fiber f) sum_(wavelenght i) w_fi ( D_fi - F_fi (R_f M)_i )
    #
    # where
    # w = inverse variance
    # D = flux data (at the resolution of the fiber)
    # F = smooth fiber flat
    # R = resolution data
    # M = mean deconvolved spectrum
    #
    # M = A^{-1} B
    # with
    # A_kl = sum_(fiber f) sum_(wavelenght i) w_fi F_fi^2 (R_fki R_fli)
    # B_k = sum_(fiber f) sum_(wavelenght i) w_fi D_fi F_fi R_fki
    #
    # defining R'_fi = sqrt(w_fi) F_fi R_fi
    # and      D'_fi = sqrt(w_fi) D_fi
    #
    # A = sum_(fiber f) R'_f R'_f^T
    # B = sum_(fiber f) R'_f D'_f
    # (it's faster that way, and we try to use sparse matrices as much as possible)
    #

    #- if problematic fibers, set ivars to 0 and mask them with specmask.BADFIBER                 
    frame = get_fiberbitmasked_frame(frame,bitmask='flat',ivar_framemask=True)

    #- Shortcuts                                                                                                      
    nwave=frame.nwave
    nfibers=frame.nspec
    wave = frame.wave.copy()  #- this will become part of output too
    ivar = frame.ivar.copy()
    flux = frame.flux.copy()


    # iterative fitting and clipping to get precise mean spectrum


    # we first need to iterate to converge on a solution of mean spectrum
    # and smooth fiber flat. several interations are needed when
    # throughput AND resolution vary from fiber to fiber.
    # the end test is that the fiber flat has varied by less than accuracy
    # of previous iteration for all wavelength
    # we also have a max. number of iterations for this code

    nout_tot=0
    chi2pdf = 0.

    smooth_fiberflat=np.ones((flux.shape))

    chi2=np.zeros((flux.shape))

    ## mask low sn portions
    w = flux*np.sqrt(ivar)<min_sn
    ivar[w]=0

    ## 0th pass: reject pixels according to minval and maxval
    mean_spectrum = np.zeros(flux.shape[1])
    nbad=np.zeros(nfibers,dtype=int)
    for iteration in range(max_iterations):
        for i in range(flux.shape[1]):
            w = ivar[:,i]>0
            if w.sum()>0:
                mean_spectrum[i] = np.median(flux[w,i])

        nbad_it=0
        for fib in range(nfibers):
            w = ((flux[fib,:]<minval*mean_spectrum) | (flux[fib,:]>maxval*mean_spectrum)) & (ivar[fib,:]>0)
            nbad_it+=w.sum()
            nbad[fib]+=w.sum()

            if w.sum()>0:
                ivar[fib,w]=0
                log.warning("0th pass: masking {} pixels in fiber {}".format(w.sum(),fib))
            if nbad[fib]>=max_bad:
                ivar[fib,:]=0
                log.warning("0th pass: masking entire fiber {} (nbad={})".format(fib,nbad[fib]))
        if nbad_it == 0:
            break

    # 1st pass is median for spectrum, flat field without resolution
    # outlier rejection
    for iteration in range(max_iterations) :

        # use median for spectrum
        mean_spectrum=np.zeros((flux.shape[1]))
        for i in range(flux.shape[1]) :
            w=ivar[:,i]>0
            if w.sum() > 0 :
                mean_spectrum[i]=np.median(flux[w,i])

        nbad_it=0
        sum_chi2 = 0
        # not more than max_rej_it pixels per fiber at a time
        for fib in range(nfibers) :
            w=ivar[fib,:]>0
            if w.sum()==0:
                continue
            F = flux[fib,:]*0
            w=(mean_spectrum!=0) & (ivar[fib,:]>0)
            F[w]= flux[fib,w]/mean_spectrum[w]
            try :
                smooth_fiberflat[fib,:] = spline_fit(wave,wave[w],F[w],smoothing_res,ivar[fib,w]*mean_spectrum[w]**2,max_resolution=1.5*smoothing_res)
            except ValueError as err  :
                log.error("Error when smoothing the flat")
                log.error("Setting ivar=0 for fiber {} because spline fit failed".format(fib))
                ivar[fib,:] *= 0
            chi2 = ivar[fib,:]*(flux[fib,:]-mean_spectrum*smooth_fiberflat[fib,:])**2
            w=np.isnan(chi2)
            bad=np.where(chi2>nsig_clipping**2)[0]
            if bad.size>0 :
                if bad.size>max_rej_it : # not more than 5 pixels at a time
                    ii=np.argsort(chi2[bad])
                    bad=bad[ii[-max_rej_it:]]
                ivar[fib,bad] = 0
                log.warning("1st pass: rejecting {} pixels from fiber {}".format(len(bad),fib))
                nbad[fib]+=len(bad)
                if nbad[fib]>=max_bad:
                    ivar[fib,:]=0
                    log.warning("1st pass: rejecting fiber {} due to too many (new) bad pixels".format(fib))
                nbad_it+=len(bad)

            sum_chi2+=chi2.sum()
        ndf=int((ivar>0).sum()-nwave-nfibers*(nwave/smoothing_res))
        chi2pdf=0.
        if ndf>0 :
            chi2pdf=sum_chi2/ndf
        log.info("1st pass iter #{} chi2={}/{} chi2pdf={} nout={} (nsig={})".format(iteration,sum_chi2,ndf,chi2pdf,nbad_it,nsig_clipping))

        if nbad_it == 0 :
            break
    ## flatten fiberflat
    ## normalize smooth_fiberflat:
    mean=np.ones(smooth_fiberflat.shape[1])
    for i in range(smooth_fiberflat.shape[1]):
        w=ivar[:,i]>0
        if w.sum()>0:
            mean[i]=np.median(smooth_fiberflat[w,i])
    smooth_fiberflat = smooth_fiberflat/mean

    median_spectrum = mean_spectrum*1.

    previous_smooth_fiberflat = smooth_fiberflat*0
    previous_max_diff = 0.
    log.info("after 1st pass : nout = %d/%d"%(np.sum(ivar==0),np.size(ivar.flatten())))
    # 2nd pass is full solution including deconvolved spectrum, no outlier rejection
    for iteration in range(max_iterations) :
        ## reset sum_chi2
        sum_chi2=0
        log.info("2nd pass, iter %d : mean deconvolved spectrum"%iteration)

        # fit mean spectrum
        A=scipy.sparse.lil_matrix((nwave,nwave)).tocsr()
        B=np.zeros((nwave))

        # diagonal sparse matrix with content = sqrt(ivar)*flat of a given fiber
        SD=scipy.sparse.lil_matrix((nwave,nwave))

        # this is to go a bit faster
        sqrtwflat=np.sqrt(ivar)*smooth_fiberflat

        # loop on fiber to handle resolution (this is long)
        for fiber in range(nfibers) :
            if fiber%100==0 :
                log.info("2nd pass, filling matrix, iter %d fiber %d"%(iteration,fiber))

            ### R = Resolution(resolution_data[fiber])
            R = frame.R[fiber]
            SD.setdiag(sqrtwflat[fiber])

            sqrtwflatR = SD*R # each row r of R is multiplied by sqrtwflat[r]

            A = A+(sqrtwflatR.T*sqrtwflatR).tocsr()
            B += sqrtwflatR.T.dot(np.sqrt(ivar[fiber])*flux[fiber])
        A_pos_def = A.todense()
        log.info("deconvolving")
        w = A.diagonal() > 0

        A_pos_def = A_pos_def[w,:]
        A_pos_def = A_pos_def[:,w]
        mean_spectrum = np.zeros(nwave)
        try:
            mean_spectrum[w]=cholesky_solve(A_pos_def,B[w])
        except:
            mean_spectrum[w]=np.linalg.lstsq(A_pos_def,B[w])[0]
            log.info("cholesky failes, trying svd inverse in iter {}".format(iteration))

        for fiber in range(nfibers) :

            if np.sum(ivar[fiber]>0)==0 :
                continue

            ### R = Resolution(resolution_data[fiber])
            R = frame.R[fiber]

            M = R.dot(mean_spectrum)
            ok=(M!=0) & (ivar[fiber,:]>0)
            if ok.sum()==0:
                continue
            try :
                smooth_fiberflat[fiber] = spline_fit(wave,wave[ok],flux[fiber,ok]/M[ok],smoothing_res,ivar[fiber,ok]*M[ok]**2,max_resolution=1.5*smoothing_res)*(ivar[fiber,:]*M**2>0)
            except ValueError as err  :
                log.error("Error when smoothing the flat")
                log.error("Setting ivar=0 for fiber {} because spline fit failed".format(fiber))
                ivar[fiber,:] *= 0
            chi2 = ivar[fiber]*(flux[fiber]-smooth_fiberflat[fiber]*M)**2
            sum_chi2 += chi2.sum()
            w=np.isnan(smooth_fiberflat[fiber])
            if w.sum()>0:
                ivar[fiber]=0
                smooth_fiberflat[fiber]=1

        # normalize to get a mean fiberflat=1
        mean = np.ones(smooth_fiberflat.shape[1])
        for i in range(nwave):
            w = ivar[:,i]>0
            if w.sum()>0:
                mean[i]=np.median(smooth_fiberflat[w,i])
        ok=np.where(mean!=0)[0]
        smooth_fiberflat[:,ok] /= mean[ok]

        # this is the max difference between two iterations
        max_diff=np.max(np.abs(smooth_fiberflat-previous_smooth_fiberflat)*(ivar>0.))
        previous_smooth_fiberflat=smooth_fiberflat.copy()

        ndf=int(np.sum(ivar>0)-nwave-nfibers*(nwave/smoothing_res))
        chi2pdf=0.
        if ndf>0 :
            chi2pdf=sum_chi2/ndf
        log.info("2nd pass, iter %d, chi2=%f ndf=%d chi2pdf=%f"%(iteration,sum_chi2,ndf,chi2pdf))


        if max_diff<accuracy :
            break

        if np.abs(max_diff-previous_max_diff)<accuracy*0.1 :
            log.warning("no significant improvement on max diff, quit loop")
            break
        
        previous_max_diff=max_diff
        
        log.info("2nd pass, iter %d, max diff. = %g > requirement = %g, continue iterating"%(iteration,max_diff,accuracy))



    
    log.info("Total number of masked pixels=%d"%nout_tot)
    log.info("3rd pass, final computation of fiber flat")
    
    # now use mean spectrum to compute flat field correction without any smoothing
    # because sharp feature can arise if dead columns

    fiberflat=np.ones((flux.shape))
    fiberflat_ivar=np.zeros((flux.shape))
    mask=np.zeros((flux.shape), dtype='uint32')

    # reset ivar
    ivar = frame.ivar.copy()
    
    fiberflat_mask=12 # place holder for actual mask bit when defined

    nsig_for_mask=nsig_clipping # only mask out N sigma outliers

    for fiber in range(nfibers) :

        if np.sum(ivar[fiber]>0)==0 :
            continue

        ### R = Resolution(resolution_data[fiber])
        R = frame.R[fiber]
        M = np.array(np.dot(R.todense(),mean_spectrum)).flatten()
        fiberflat[fiber] = (M!=0)*flux[fiber]/(M+(M==0)) + (M==0)
        fiberflat_ivar[fiber] = ivar[fiber]*M**2
        nbad_tot=0
        iteration=0
        while iteration<500 :
            w=fiberflat_ivar[fiber,:]>0
            if w.sum()<100:
                break
            try :
                smooth_fiberflat=spline_fit(wave,wave[w],fiberflat[fiber,w],smoothing_res,fiberflat_ivar[fiber,w])
            except ValueError as e :
                print("error in spline_fit")
                mask[fiber] += fiberflat_mask
                fiberflat_ivar[fiber] = 0.
                break
                
            chi2=fiberflat_ivar[fiber]*(fiberflat[fiber]-smooth_fiberflat)**2
            bad=np.where(chi2>nsig_for_mask**2)[0]
            if bad.size>0 :
                
                nbadmax=1
                if bad.size>nbadmax : # not more than nbadmax pixels at a time
                    ii=np.argsort(chi2[bad])
                    bad=bad[ii[-nbadmax:]]

                mask[fiber,bad] += fiberflat_mask
                fiberflat_ivar[fiber,bad] = 0.
                nbad_tot += bad.size
            else :
                break
            iteration += 1

        
        log.info("3rd pass : fiber #%d , number of iterations %d"%(fiber,iteration))
    
    
    # set median flat to 1
    log.info("3rd pass : set median fiberflat to 1")

    mean=np.ones((flux.shape[1]))
    for i in range(flux.shape[1]) :
        ok=np.where((mask[:,i]==0)&(ivar[:,i]>0))[0]
        if ok.size > 0 :
            mean[i] = np.median(fiberflat[ok,i])
    ok=np.where(mean!=0)[0]
    for fiber in range(nfibers) :
        fiberflat[fiber,ok] /= mean[ok]  

    log.info("3rd pass : interpolating over masked pixels")


    for fiber in range(nfibers) :

        if np.sum(ivar[fiber]>0)==0 :
            continue
        # replace bad by smooth fiber flat
        bad=np.where((mask[fiber]>0)|(fiberflat_ivar[fiber]==0)|(fiberflat[fiber]<minval)|(fiberflat[fiber]>maxval))[0]
        
        if bad.size>0 :

            fiberflat_ivar[fiber,bad] = 0

            # find max length of segment with bad pix
            length=0
            for i in range(bad.size) :
                ib=bad[i]
                ilength=1
                tmp=ib
                for jb in bad[i+1:] :
                    if jb==tmp+1 :
                        ilength +=1
                        tmp=jb
                    else :
                        break
                length=max(length,ilength)
            if length>10 :
                log.info("3rd pass : fiber #%d has a max length of bad pixels=%d"%(fiber,length))
            smoothing_res=float(max(100,length))
            x=np.arange(wave.size)

            ok=fiberflat_ivar[fiber]>0
            if ok.sum()==0:
                continue
            try:
                smooth_fiberflat=spline_fit(x,x[ok],fiberflat[fiber,ok],smoothing_res,fiberflat_ivar[fiber,ok])
                fiberflat[fiber,bad] = smooth_fiberflat[bad]
            except:
                fiberflat[fiber,bad] = 1
                fiberflat_ivar[fiber,bad]=0

        if nbad_tot>0 :
            log.info("3rd pass : fiber #%d masked pixels = %d (%d iterations)"%(fiber,nbad_tot,iteration))

    # set median flat to 1
    log.info("set median fiberflat to 1")

    mean=np.ones((flux.shape[1]))
    for i in range(flux.shape[1]) :
        ok=np.where((mask[:,i]==0)&(ivar[:,i]>0))[0]
        if ok.size > 0 :
            mean[i] = np.median(fiberflat[ok,i])
    ok=np.where(mean!=0)[0]
    for fiber in range(nfibers) :
        fiberflat[fiber,ok] /= mean[ok]

    log.info("done fiberflat")

    log.info("add a systematic error of 0.0035 to fiberflat variance (calibrated on sims)")
    fiberflat_ivar = (fiberflat_ivar>0)/( 1./ (fiberflat_ivar+(fiberflat_ivar==0) ) + 0.0035**2)

    fiberflat = FiberFlat(wave, fiberflat, fiberflat_ivar, mask, mean_spectrum,
                     chi2pdf=chi2pdf,header=frame.meta,fibermap=frame.fibermap)
        
    #for broken_fiber in broken_fibers :
    #    log.info("mask broken fiber {} in flat".format(broken_fiber))
    #    fiberflat.fiberflat[fiber]=1.
    #    fiberflat.ivar[fiber]=0.
    #    fiberflat.mask[fiber]=specmask.BADFIBERFLAT
    
    return fiberflat
Ejemplo n.º 4
0
def main(args):
    """Compute the flux calibration of one frame from standard-star models.

    Reads the frame, applies the fiberflat, subtracts the sky, selects the
    standard stars whose model fits pass the requested cuts, computes the
    flux calibration and writes it to ``args.outfile``.  When ``args.qafile``
    is set, the FLUXCALIB QA is also run and written.

    Args:
        args: parsed command-line namespace.  Attributes read here are
            ``infile``, ``fiberflat``, ``sky``, ``models``, ``outfile``,
            ``chi2cut``, ``chi2cut_nsig``, ``delta_color_cut``,
            ``min_color``, ``color``, ``highest_throughput``,
            ``seeing_fwhm``, ``qafile`` and ``qafig``.

    Returns:
        None.  Returns early, without writing any output, when every
        selected standard-star spectrum is fully masked.

    Exits (through ``sys.exit``):
        12: all stars were discarded by the cuts, or the model fibers are
            not flagged as standard stars in the fibermap.
        14: ``args.color`` is not one of the supported color names.
        15: no color was requested and neither ``MODEL_G-R`` nor
            ``MODEL_GAIA-BP-RP`` exists in the model file.
        16: the requested color has no ``MODEL_<color>`` column.
    """
    log = get_logger()

    # Log the equivalent command line for traceability.
    cmd = ['desi_compute_fluxcalibration']
    for key, value in vars(args).items():
        if value is not None:
            cmd += ['--' + key, str(value)]
    log.info(' '.join(cmd))

    log.info("read frame")
    frame = read_frame(args.infile)

    # Set fibermask flagged spectra to have 0 flux and variance
    frame = get_fiberbitmasked_frame(frame,
                                     bitmask='flux',
                                     ivar_framemask=True)

    log.info("apply fiberflat")
    fiberflat = read_fiberflat(args.fiberflat)
    apply_fiberflat(frame, fiberflat)

    log.info("subtract sky")
    skymodel = read_sky(args.sky)
    subtract_sky(frame, skymodel)

    log.info("compute flux calibration")

    # read models
    model_flux, model_wave, model_fibers, model_metadata = read_stdstar_models(
        args.models)

    # Boolean selection of the stars that survive all the cuts below.
    ok = np.ones(len(model_metadata), dtype=bool)

    if args.chi2cut > 0:
        log.info("apply cut CHI2DOF<{}".format(args.chi2cut))
        good = (model_metadata["CHI2DOF"] < args.chi2cut)
        bad = ok & (~good)
        ok &= good
        if np.any(bad):
            log.info(" discard {} stars with CHI2DOF= {}".format(
                np.sum(bad), list(model_metadata["CHI2DOF"][bad])))

    legacy_filters = ('G-R', 'R-Z')
    gaia_filters = ('GAIA-BP-RP', 'GAIA-G-RP')
    model_column_list = model_metadata.columns.names
    if args.color is None:
        # No color requested: pick whichever one the model file provides.
        if 'MODEL_G-R' in model_column_list:
            color = 'G-R'
        elif 'MODEL_GAIA-BP-RP' in model_column_list:
            log.info('Using Gaia filters')
            color = 'GAIA-BP-RP'
        else:
            log.error(
                "Can't find either G-R or BP-RP color in the model file.")
            sys.exit(15)
    else:
        if args.color not in legacy_filters and args.color not in gaia_filters:
            log.error(
                'Color name {} is not allowed, must be one of {} {}'.format(
                    args.color, legacy_filters, gaia_filters))
            sys.exit(14)
        color = args.color
        # Model colors live in "MODEL_<color>" columns (same naming as the
        # auto-detection above and the delta-color cut below), so the
        # prefixed name is the one that must exist.
        if 'MODEL_' + color not in model_column_list:
            # This shouldn't happen
            log.error(
                'The color {} was not computed in the models'.format(color))
            sys.exit(16)

    if args.delta_color_cut > 0:
        log.info("apply cut |delta color|<{}".format(args.delta_color_cut))
        delta_color = (model_metadata["MODEL_" + color] -
                       model_metadata["DATA_" + color])
        good = (np.abs(delta_color) < args.delta_color_cut)
        bad = ok & (~good)
        ok &= good
        if np.any(bad):
            vals = (model_metadata["MODEL_" + color][bad] -
                    model_metadata["DATA_" + color][bad])
            log.info(" discard {} stars with dcolor= {}".format(
                np.sum(bad), list(vals)))

    if args.min_color is not None:
        log.info("apply cut DATA_{}>{}".format(color, args.min_color))
        good = (model_metadata["DATA_{}".format(color)] > args.min_color)
        bad = ok & (~good)
        ok &= good
        if np.any(bad):
            vals = model_metadata["DATA_{}".format(color)][bad]
            log.info(" discard {} stars with {}= {}".format(
                np.sum(bad), color, list(vals)))

    if args.chi2cut_nsig > 0:
        # automatically reject stars that are chi2 outliers; the median and
        # rms are taken over all stars of the model file
        mchi2 = np.median(model_metadata["CHI2DOF"])
        rmschi2 = np.std(model_metadata["CHI2DOF"])
        maxchi2 = mchi2 + args.chi2cut_nsig * rmschi2
        log.info("apply cut CHI2DOF<{} based on chi2cut_nsig={}".format(
            maxchi2, args.chi2cut_nsig))
        good = (model_metadata["CHI2DOF"] <= maxchi2)
        bad = ok & (~good)
        ok &= good
        if np.any(bad):
            log.info(" discard {} stars with CHI2DOF={}".format(
                np.sum(bad), list(model_metadata["CHI2DOF"][bad])))

    ok = np.where(ok)[0]
    if ok.size == 0:
        log.error("selection cuts discarded all stars")
        sys.exit(12)
    nstars = model_flux.shape[0]
    nbad = nstars - ok.size
    if nbad > 0:
        log.warning("discarding %d star(s) out of %d because of cuts" %
                    (nbad, nstars))
        model_flux = model_flux[ok]
        model_fibers = model_fibers[ok]
        model_metadata = model_metadata[:][ok]

    # check that the model_fibers are actually standard stars
    fibermap = frame.fibermap

    # Row index of each model star in this frame.
    # NOTE(review): the modulo presumably assumes 500 fibers per frame - confirm.
    model_indices = model_fibers % 500

    # Check whether star fibers from args.models are consistent with fibers
    # from fibermap; if not, print the OBJTYPE from fibermap for the fiber
    # numbers in args.models and exit.
    fibermap_std_indices = np.where(isStdStar(fibermap))[0]
    if np.any(~np.in1d(model_indices, fibermap_std_indices)):
        target_colnames, target_masks, survey = main_cmx_or_sv(fibermap)
        colname = target_colnames[0]
        for i in model_indices:
            log.error(
                "inconsistency with spectrum {}, OBJTYPE={}, {}={} in fibermap"
                .format(i, fibermap["OBJTYPE"][i], colname,
                        fibermap[colname][i]))
        sys.exit(12)

    # Make sure the fibers of interest aren't entirely masked.
    n_masked_per_star = np.sum(frame.ivar[model_indices, :] == 0, axis=1)
    if np.sum(n_masked_per_star == frame.nwave) == len(model_fibers):
        log.warning('All standard-star spectra are masked!')
        return

    fluxcalib = compute_flux_calibration(
        frame,
        model_wave,
        model_flux,
        model_indices,
        highest_throughput_nstars=args.highest_throughput,
        exposure_seeing_fwhm=args.seeing_fwhm)

    # QA
    if args.qafile is not None:

        from desispec.io import write_qa_frame
        from desispec.io.qa import load_qa_frame
        from desispec.qa import qa_plots

        log.info("performing fluxcalib QA")
        # Load
        qaframe = load_qa_frame(args.qafile,
                                frame_meta=frame.meta,
                                flavor=frame.meta['FLAVOR'])
        # Run
        qaframe.run_qa('FLUXCALIB', (frame, fluxcalib))
        # Write
        write_qa_frame(args.qafile, qaframe)
        log.info("successfully wrote {:s}".format(args.qafile))
        # Figure(s)
        if args.qafig is not None:
            qa_plots.frame_fluxcalib(args.qafig, qaframe, frame, fluxcalib)

    # record inputs
    frame.meta['IN_FRAME'] = shorten_filename(args.infile)
    frame.meta['IN_SKY'] = shorten_filename(args.sky)
    frame.meta['FIBERFLT'] = shorten_filename(args.fiberflat)
    frame.meta['STDMODEL'] = shorten_filename(args.models)

    # write result
    write_flux_calibration(args.outfile, fluxcalib, header=frame.meta)

    log.info("successfully wrote %s" % args.outfile)
Ejemplo n.º 5
0
def main(args):
    """Apply fiberflat, sky subtraction and flux calibration to a frame.

    Each step runs only when its input file is given.  Frame scores are
    appended after every step that ran, and the resulting frame is written
    to ``args.outfile``.

    Args:
        args: parsed command-line namespace.  Attributes read here are
            ``infile``, ``outfile``, ``fiberflat``, ``sky``, ``calib``,
            ``cosmics_nsig`` and ``no_sky_throughput_correction``.

    Exits (through ``sys.exit``):
        12: none of ``fiberflat``, ``sky`` or ``calib`` was given.
    """
    log = get_logger()

    if args.fiberflat is None and args.sky is None and args.calib is None:
        log.critical('no --fiberflat, --sky, or --calib; nothing to do ?!?')
        sys.exit(12)

    frame = read_frame(args.infile)

    #- Raw scores already added in extraction, but just in case they weren't
    #- it is harmless to rerun to make sure we have them.
    compute_and_append_frame_scores(frame, suffix="RAW")

    # Without a sky model, reject cosmics right away; otherwise it is done
    # after sky subtraction (see below).
    if args.cosmics_nsig > 0 and args.sky is None:
        log.info("cosmics ray 1D rejection")
        reject_cosmic_rays_1d(frame, args.cosmics_nsig)

    if args.fiberflat is not None:
        log.info("apply fiberflat")
        fiberflat = read_fiberflat(args.fiberflat)

        # apply fiberflat to all fibers
        apply_fiberflat(frame, fiberflat)
        compute_and_append_frame_scores(frame, suffix="FFLAT")

    if args.sky is not None:
        skymodel = read_sky(args.sky)

        if args.cosmics_nsig > 0:
            # Find cosmics on a sky-subtracted copy of the frame (not
            # elegant but robust), subtracting the sky without throughput
            # correction, then carry only the resulting mask back.
            copied_frame = copy.deepcopy(frame)
            subtract_sky(copied_frame, skymodel,
                         apply_throughput_correction=False)

            log.info("cosmics ray 1D rejection after sky subtraction")
            reject_cosmic_rays_1d(copied_frame, args.cosmics_nsig)

            frame.mask = copied_frame.mask

        # Actual sky subtraction on the frame, with the throughput
        # correction unless it was disabled on the command line.
        subtract_sky(
            frame, skymodel,
            apply_throughput_correction=(not args.no_sky_throughput_correction))

        compute_and_append_frame_scores(frame, suffix="SKYSUB")

    if args.calib is not None:
        log.info("calibrate")
        fluxcalib = read_flux_calibration(args.calib)
        apply_flux_calibration(frame, fluxcalib)

        # Ensure that ivars are set to 0 for all values if any designated
        # fibermask bit is set. Also flips a bit for each frame.mask value
        # using specmask.BADFIBER
        frame = get_fiberbitmasked_frame(frame, bitmask="flux",
                                         ivar_framemask=True)
        compute_and_append_frame_scores(frame, suffix="CALIB")

    # save output
    write_frame(args.outfile, frame, units='10**-17 erg/(s cm2 Angstrom)')
    log.info("successfully wrote %s" % args.outfile)
Ejemplo n.º 6
0
def main(args, comm=None):
    """ finds the best models of all standard stars in the frame
    and normlize the model flux. Output is written to a file and will be called for calibration.
    """

    log = get_logger()

    log.info("mag delta %s = %f (for the pre-selection of stellar models)" %
             (args.color, args.delta_color))

    if args.mpi or comm is not None:
        from mpi4py import MPI
        if comm is None:
            comm = MPI.COMM_WORLD
        size = comm.Get_size()
        rank = comm.Get_rank()
        if rank == 0:
            log.info('mpi parallelizing with {} ranks'.format(size))
    else:
        comm = None
        rank = 0
        size = 1

    # disable multiprocess by forcing ncpu = 1 when using MPI
    if comm is not None:
        ncpu = 1
        if rank == 0:
            log.info('disabling multiprocess (forcing ncpu = 1)')
    else:
        ncpu = args.ncpu

    if ncpu > 1:
        if rank == 0:
            log.info(
                'multiprocess parallelizing with {} processes'.format(ncpu))

    if args.ignore_gpu and desispec.fluxcalibration.use_gpu:
        # Opt-out of GPU usage
        desispec.fluxcalibration.use_gpu = False
        if rank == 0:
            log.info('ignoring GPU')
    elif desispec.fluxcalibration.use_gpu:
        # Nothing to do here, GPU is used by default if available
        if rank == 0:
            log.info('using GPU')
    else:
        if rank == 0:
            log.info('GPU not available')

    std_targetids = None
    if args.std_targetids is not None:
        std_targetids = args.std_targetids

    # READ DATA
    ############################################
    # First loop through and group by exposure and spectrograph
    frames_by_expid = {}
    rows = list()
    for filename in args.frames:
        log.info("reading %s" % filename)
        frame = io.read_frame(filename)
        night = safe_read_key(frame.meta, "NIGHT")
        expid = safe_read_key(frame.meta, "EXPID")
        camera = safe_read_key(frame.meta, "CAMERA").strip().lower()
        rows.append((night, expid, camera))
        spec = camera[1]
        uniq_key = (expid, spec)
        if uniq_key in frames_by_expid.keys():
            frames_by_expid[uniq_key][camera] = frame
        else:
            frames_by_expid[uniq_key] = {camera: frame}

    input_frames_table = Table(rows=rows, names=('NIGHT', 'EXPID', 'TILEID'))

    frames = {}
    flats = {}
    skies = {}

    spectrograph = None
    starfibers = None
    starindices = None
    fibermap = None
    # For each unique expid,spec pair, get the logical OR of the FIBERSTATUS for all
    # cameras and then proceed with extracting the frame information
    # once we modify the fibermap FIBERSTATUS
    for (expid, spec), camdict in frames_by_expid.items():

        fiberstatus = None
        for frame in camdict.values():
            if fiberstatus is None:
                fiberstatus = frame.fibermap['FIBERSTATUS'].data.copy()
            else:
                fiberstatus |= frame.fibermap['FIBERSTATUS']

        for camera, frame in camdict.items():
            frame.fibermap['FIBERSTATUS'] |= fiberstatus
            # Set fibermask flagged spectra to have 0 flux and variance
            frame = get_fiberbitmasked_frame(frame,
                                             bitmask='stdstars',
                                             ivar_framemask=True)
            frame_fibermap = frame.fibermap
            if std_targetids is None:
                frame_starindices = np.where(isStdStar(frame_fibermap))[0]
            else:
                frame_starindices = np.nonzero(
                    np.isin(frame_fibermap['TARGETID'], std_targetids))[0]

            #- Confirm that all fluxes have entries but trust targeting bits
            #- to get basic magnitude range correct
            keep_legacy = np.ones(len(frame_starindices), dtype=bool)

            for colname in ['FLUX_G', 'FLUX_R', 'FLUX_Z']:  #- and W1 and W2?
                keep_legacy &= frame_fibermap[colname][
                    frame_starindices] > 10**((22.5 - 30) / 2.5)
                keep_legacy &= frame_fibermap[colname][
                    frame_starindices] < 10**((22.5 - 0) / 2.5)
            keep_gaia = np.ones(len(frame_starindices), dtype=bool)

            for colname in ['G', 'BP', 'RP']:  #- and W1 and W2?
                keep_gaia &= frame_fibermap[
                    'GAIA_PHOT_' + colname +
                    '_MEAN_MAG'][frame_starindices] > 10
                keep_gaia &= frame_fibermap[
                    'GAIA_PHOT_' + colname +
                    '_MEAN_MAG'][frame_starindices] < 20
            n_legacy_std = keep_legacy.sum()
            n_gaia_std = keep_gaia.sum()
            keep = keep_legacy | keep_gaia
            # accept both types of standards for the time being

            # keep the indices for gaia/legacy subsets
            gaia_indices = keep_gaia[keep]
            legacy_indices = keep_legacy[keep]

            frame_starindices = frame_starindices[keep]

            if spectrograph is None:
                spectrograph = frame.spectrograph
                fibermap = frame_fibermap
                starindices = frame_starindices
                starfibers = fibermap["FIBER"][starindices]

            elif spectrograph != frame.spectrograph:
                log.error("incompatible spectrographs {} != {}".format(
                    spectrograph, frame.spectrograph))
                raise ValueError("incompatible spectrographs {} != {}".format(
                    spectrograph, frame.spectrograph))
            elif starindices.size != frame_starindices.size or np.sum(
                    starindices != frame_starindices) > 0:
                log.error("incompatible fibermap")
                raise ValueError("incompatible fibermap")

            if not camera in frames:
                frames[camera] = []

            frames[camera].append(frame)

    # possibly cleanup memory
    del frames_by_expid

    for filename in args.skymodels:
        log.info("reading %s" % filename)
        sky = io.read_sky(filename)
        camera = safe_read_key(sky.header, "CAMERA").strip().lower()
        if not camera in skies:
            skies[camera] = []
        skies[camera].append(sky)

    for filename in args.fiberflats:
        log.info("reading %s" % filename)
        flat = io.read_fiberflat(filename)
        camera = safe_read_key(flat.header, "CAMERA").strip().lower()

        # NEED TO ADD MORE CHECKS
        if camera in flats:
            log.warning(
                "cannot handle several flats of same camera (%s), will use only the first one"
                % camera)
            #raise ValueError("cannot handle several flats of same camera (%s)"%camera)
        else:
            flats[camera] = flat

    # if color is not specified we decide on the fly
    color = args.color
    if color is not None:
        if color[:4] == 'GAIA':
            legacy_color = False
            gaia_color = True
        else:
            legacy_color = True
            gaia_color = False
        if n_legacy_std == 0 and legacy_color:
            raise Exception(
                'Specified Legacy survey color, but no legacy standards')
        if n_gaia_std == 0 and gaia_color:
            raise Exception('Specified gaia color, but no gaia stds')

    if starindices.size == 0:
        log.error("no STD star found in fibermap")
        raise ValueError("no STD star found in fibermap")
    log.info("found %d STD stars" % starindices.size)

    if n_legacy_std == 0:
        gaia_std = True
        if color is None:
            color = 'GAIA-BP-RP'
    else:
        gaia_std = False
        if color is None:
            color = 'G-R'
        if n_gaia_std > 0:
            log.info('Gaia standards found but not used')

    if gaia_std:
        # The name of the reference filter to which we normalize the flux
        ref_mag_name = 'GAIA-G'
        color_band1, color_band2 = ['GAIA-' + _ for _ in color[5:].split('-')]
        log.info(
            "Using Gaia standards with color {} and normalizing to {}".format(
                color, ref_mag_name))
        # select appropriate subset of standards
        starindices = starindices[gaia_indices]
        starfibers = starfibers[gaia_indices]
    else:
        ref_mag_name = 'R'
        color_band1, color_band2 = color.split('-')
        log.info("Using Legacy standards with color {} and normalizing to {}".
                 format(color, ref_mag_name))
        # select appropriate subset of standards
        starindices = starindices[legacy_indices]
        starfibers = starfibers[legacy_indices]

    # excessive check but just in case
    if not color in ['G-R', 'R-Z', 'GAIA-BP-RP', 'GAIA-G-RP']:
        raise ValueError('Unknown color {}'.format(color))

    # log.warning("Not using flux errors for Standard Star fits!")

    # DIVIDE FLAT AND SUBTRACT SKY , TRIM DATA
    ############################################
    # since poping dict, we need to copy keys to iterate over to avoid
    # RuntimeError due to changing dict
    frame_cams = list(frames.keys())
    for cam in frame_cams:

        if not cam in skies:
            log.warning("Missing sky for %s" % cam)
            frames.pop(cam)
            continue
        if not cam in flats:
            log.warning("Missing flat for %s" % cam)
            frames.pop(cam)
            continue

        flat = flats[cam]
        for frame, sky in zip(frames[cam], skies[cam]):
            frame.flux = frame.flux[starindices]
            frame.ivar = frame.ivar[starindices]
            frame.ivar *= (frame.mask[starindices] == 0)
            frame.ivar *= (sky.ivar[starindices] != 0)
            frame.ivar *= (sky.mask[starindices] == 0)
            frame.ivar *= (flat.ivar[starindices] != 0)
            frame.ivar *= (flat.mask[starindices] == 0)
            frame.flux *= (frame.ivar > 0)  # just for clean plots
            for star in range(frame.flux.shape[0]):
                ok = np.where((frame.ivar[star] > 0)
                              & (flat.fiberflat[star] != 0))[0]
                if ok.size > 0:
                    frame.flux[star] = frame.flux[star] / flat.fiberflat[
                        star] - sky.flux[star]
            frame.resolution_data = frame.resolution_data[starindices]

        nframes = len(frames[cam])
        if nframes > 1:
            # optimal weights for the coaddition = ivar*throughput, not directly ivar,
            # we estimate the relative throughput with median fluxes at this stage
            medflux = np.zeros(nframes)
            for i, frame in enumerate(frames[cam]):
                if np.sum(frame.ivar > 0) == 0:
                    log.error(
                        "ivar=0 for all std star spectra in frame {}-{:08d}".
                        format(cam, frame.meta["EXPID"]))
                else:
                    medflux[i] = np.median(frame.flux[frame.ivar > 0])
            log.debug("medflux = {}".format(medflux))
            medflux *= (medflux > 0)
            if np.sum(medflux > 0) == 0:
                log.error(
                    "mean median flux = 0, for all stars in fibers {}".format(
                        list(frames[cam][0].fibermap["FIBER"][starindices])))
                sys.exit(12)
            mmedflux = np.mean(medflux[medflux > 0])
            weights = medflux / mmedflux
            log.info("coadding {} exposures in cam {}, w={}".format(
                nframes, cam, weights))

            sw = np.zeros(frames[cam][0].flux.shape)
            swf = np.zeros(frames[cam][0].flux.shape)
            swr = np.zeros(frames[cam][0].resolution_data.shape)

            for i, frame in enumerate(frames[cam]):
                sw += weights[i] * frame.ivar
                swf += weights[i] * frame.ivar * frame.flux
                swr += weights[i] * frame.ivar[:,
                                               None, :] * frame.resolution_data
            coadded_frame = frames[cam][0]
            coadded_frame.ivar = sw
            coadded_frame.flux = swf / (sw + (sw == 0))
            coadded_frame.resolution_data = swr / ((sw +
                                                    (sw == 0))[:, None, :])
            frames[cam] = [coadded_frame]

    # CHECK S/N
    ############################################
    # for each band in 'brz', record quadratic sum of median S/N across wavelength
    snr = dict()
    for band in ['b', 'r', 'z']:
        snr[band] = np.zeros(starindices.size)
    for cam in frames:
        band = cam[0].lower()
        for frame in frames[cam]:
            msnr = np.median(frame.flux * np.sqrt(frame.ivar) /
                             np.sqrt(np.gradient(frame.wave)),
                             axis=1)  # median SNR per sqrt(A.)
            msnr *= (msnr > 0)
            snr[band] = np.sqrt(snr[band]**2 + msnr**2)
    log.info("SNR(B) = {}".format(snr['b']))

    ###############################
    max_number_of_stars = 50
    min_blue_snr = 4.
    ###############################
    indices = np.argsort(snr['b'])[::-1][:max_number_of_stars]

    validstars = np.where(snr['b'][indices] > min_blue_snr)[0]

    #- TODO: later we filter on models based upon color, thus throwing
    #- away very blue stars for which we don't have good models.

    log.info("Number of stars with median stacked blue S/N > {} /sqrt(A) = {}".
             format(min_blue_snr, validstars.size))
    if validstars.size == 0:
        log.error("No valid star")
        sys.exit(12)

    validstars = indices[validstars]

    for band in ['b', 'r', 'z']:
        snr[band] = snr[band][validstars]

    log.info("BLUE SNR of selected stars={}".format(snr['b']))

    for cam in frames:
        for frame in frames[cam]:
            frame.flux = frame.flux[validstars]
            frame.ivar = frame.ivar[validstars]
            frame.resolution_data = frame.resolution_data[validstars]
    starindices = starindices[validstars]
    starfibers = starfibers[validstars]
    nstars = starindices.size
    fibermap = Table(fibermap[starindices])

    # MASK OUT THROUGHPUT DIP REGION
    ############################################
    mask_throughput_dip_region = True
    if mask_throughput_dip_region:
        wmin = 4300.
        wmax = 4500.
        log.warning(
            "Masking out the wavelength region [{},{}]A in the standard star fit"
            .format(wmin, wmax))
    for cam in frames:
        for frame in frames[cam]:
            ii = np.where((frame.wave >= wmin) & (frame.wave <= wmax))[0]
            if ii.size > 0:
                frame.ivar[:, ii] = 0

    # READ MODELS
    ############################################
    log.info("reading star models in %s" % args.starmodels)
    stdwave, stdflux, templateid, teff, logg, feh = io.read_stdstar_templates(
        args.starmodels)

    # COMPUTE MAGS OF MODELS FOR EACH STD STAR MAG
    ############################################

    #- Support older fibermaps
    if 'PHOTSYS' not in fibermap.colnames:
        log.warning('Old fibermap format; using defaults for missing columns')
        log.warning("    PHOTSYS = 'S'")
        log.warning("    EBV = 0.0")
        fibermap['PHOTSYS'] = 'S'
        fibermap['EBV'] = 0.0

    if not np.in1d(np.unique(fibermap['PHOTSYS']), ['', 'N', 'S', 'G']).all():
        log.error('Unknown PHOTSYS found')
        raise Exception('Unknown PHOTSYS found')
    # Fetching Filter curves
    model_filters = dict()
    for band in ["G", "R", "Z"]:
        for photsys in np.unique(fibermap['PHOTSYS']):
            if photsys in ['N', 'S']:
                model_filters[band + photsys] = load_legacy_survey_filter(
                    band=band, photsys=photsys)
    if len(model_filters) == 0:
        log.info('No Legacy survey photometry identified in fibermap')

    # I will always load gaia data even if we are fitting LS standards only
    for band in ["G", "BP", "RP"]:
        model_filters["GAIA-" + band] = load_gaia_filter(band=band, dr=2)

    # Compute model mags on rank 0 and bcast result to other ranks
    # This sidesteps an OOM event on Cori Haswell with "-c 2"
    model_mags = None
    if rank == 0:
        log.info("computing model mags for %s" % sorted(model_filters.keys()))
        model_mags = dict()
        for filter_name in model_filters.keys():
            model_mags[filter_name] = get_magnitude(stdwave, stdflux,
                                                    model_filters, filter_name)
        log.info("done computing model mags")
    if comm is not None:
        model_mags = comm.bcast(model_mags, root=0)

    # LOOP ON STARS TO FIND BEST MODEL
    ############################################
    star_mags = dict()
    star_unextincted_mags = dict()

    if gaia_std and (fibermap['EBV'] == 0).all():
        log.info("Using E(B-V) from SFD rather than FIBERMAP")
        # when doing gaia standards, on old tiles the
        # EBV is not set so we fetch from SFD (in original SFD scaling)
        ebv = SFDMap(scaling=1).ebv(
            acoo.SkyCoord(ra=fibermap['TARGET_RA'] * units.deg,
                          dec=fibermap['TARGET_DEC'] * units.deg))
    else:
        ebv = fibermap['EBV']

    photometric_systems = np.unique(fibermap['PHOTSYS'])
    if not gaia_std:
        for band in ['G', 'R', 'Z']:
            star_mags[band] = 22.5 - 2.5 * np.log10(fibermap['FLUX_' + band])
            star_unextincted_mags[band] = np.zeros(star_mags[band].shape)
            for photsys in photometric_systems:
                r_band = extinction_total_to_selective_ratio(
                    band, photsys)  # dimensionless
                # r_band = a_band / E(B-V)
                # E(B-V) is a difference of magnitudes (dimensionless)
                # a_band = -2.5*log10(effective dust transmission) , dimensionless
                # effective dust transmission =
                #                  integral( SED(lambda) * filter_transmission(lambda,band) * dust_transmission(lambda,E(B-V)) dlambda)
                #                / integral( SED(lambda) * filter_transmission(lambda,band) dlambda)
                selection = (fibermap['PHOTSYS'] == photsys)
                a_band = r_band * ebv[selection]  # dimensionless
                star_unextincted_mags[band][selection] = 22.5 - 2.5 * np.log10(
                    fibermap['FLUX_' + band][selection]) - a_band

    for band in ['G', 'BP', 'RP']:
        star_mags['GAIA-' + band] = fibermap['GAIA_PHOT_' + band + '_MEAN_MAG']

    for band, extval in gaia_extinction(star_mags['GAIA-G'],
                                        star_mags['GAIA-BP'],
                                        star_mags['GAIA-RP'], ebv).items():
        star_unextincted_mags['GAIA-' +
                              band] = star_mags['GAIA-' + band] - extval

    star_colors = dict()
    star_unextincted_colors = dict()

    # compute the colors and define the unextincted colors
    # the unextincted colors are filled later
    if not gaia_std:
        for c1, c2 in ['GR', 'RZ']:
            star_colors[c1 + '-' + c2] = star_mags[c1] - star_mags[c2]
            star_unextincted_colors[c1 + '-' +
                                    c2] = (star_unextincted_mags[c1] -
                                           star_unextincted_mags[c2])
    for c1, c2 in [('BP', 'RP'), ('G', 'RP')]:
        star_colors['GAIA-' + c1 + '-' + c2] = (star_mags['GAIA-' + c1] -
                                                star_mags['GAIA-' + c2])
        star_unextincted_colors['GAIA-' + c1 + '-' +
                                c2] = (star_unextincted_mags['GAIA-' + c1] -
                                       star_unextincted_mags['GAIA-' + c2])

    linear_coefficients = np.zeros((nstars, stdflux.shape[0]))
    chi2dof = np.zeros((nstars))
    redshift = np.zeros((nstars))
    normflux = np.zeros((nstars, stdwave.size))
    fitted_model_colors = np.zeros(nstars)

    local_comm, head_comm = None, None
    if comm is not None:
        # All ranks in local_comm work on the same stars
        local_comm = comm.Split(rank % nstars, rank)
        # The color 1 in head_comm contains all ranks that have rank 0 in local_comm
        head_comm = comm.Split(rank < nstars, rank)

    for star in range(rank % nstars, nstars, size):

        log.info("rank %d: finding best model for observed star #%d" %
                 (rank, star))

        # np.array of wave,flux,ivar,resol
        wave = {}
        flux = {}
        ivar = {}
        resolution_data = {}
        for camera in frames:
            for i, frame in enumerate(frames[camera]):
                identifier = "%s-%d" % (camera, i)
                wave[identifier] = frame.wave
                flux[identifier] = frame.flux[star]
                ivar[identifier] = frame.ivar[star]
                resolution_data[identifier] = frame.resolution_data[star]

        # preselect models based on magnitudes
        photsys = fibermap['PHOTSYS'][star]

        if gaia_std:
            model_colors = model_mags[color_band1] - model_mags[color_band2]
        else:
            model_colors = model_mags[color_band1 +
                                      photsys] - model_mags[color_band2 +
                                                            photsys]

        color_diff = model_colors - star_unextincted_colors[color][star]
        selection = np.abs(color_diff) < args.delta_color
        if np.sum(selection) == 0:
            log.warning("no model in the selected color range for this star")
            continue

        # smallest cube in parameter space including this selection (needed for interpolation)
        new_selection = (teff >= np.min(teff[selection])) & (teff <= np.max(
            teff[selection]))
        new_selection &= (logg >= np.min(logg[selection])) & (logg <= np.max(
            logg[selection]))
        new_selection &= (feh >= np.min(feh[selection])) & (feh <= np.max(
            feh[selection]))
        selection = np.where(new_selection)[0]

        log.info(
            "star#%d fiber #%d, %s = %f, number of pre-selected models = %d/%d"
            % (star, starfibers[star], color,
               star_unextincted_colors[color][star], selection.size,
               stdflux.shape[0]))

        # Match unextincted standard stars to data
        match_templates_result = match_templates(
            wave,
            flux,
            ivar,
            resolution_data,
            stdwave,
            stdflux[selection],
            teff[selection],
            logg[selection],
            feh[selection],
            ncpu=ncpu,
            z_max=args.z_max,
            z_res=args.z_res,
            template_error=args.template_error,
            comm=local_comm)

        # Only local rank 0 can perform the remaining work
        if local_comm is not None and local_comm.Get_rank() != 0:
            continue

        coefficients, redshift[star], chi2dof[star] = match_templates_result
        linear_coefficients[star, selection] = coefficients
        log.info(
            'Star Fiber: {}; TEFF: {:.3f}; LOGG: {:.3f}; FEH: {:.3f}; Redshift: {:g}; Chisq/dof: {:.3f}'
            .format(starfibers[star], np.inner(teff,
                                               linear_coefficients[star]),
                    np.inner(logg, linear_coefficients[star]),
                    np.inner(feh, linear_coefficients[star]), redshift[star],
                    chi2dof[star]))

        # Apply redshift to original spectrum at full resolution
        model = np.zeros(stdwave.size)
        redshifted_stdwave = stdwave * (1 + redshift[star])
        for i, c in enumerate(linear_coefficients[star]):
            if c != 0:
                model += c * np.interp(stdwave, redshifted_stdwave, stdflux[i])

        # Apply dust extinction to the model
        log.info("Applying MW dust extinction to star {} with EBV = {}".format(
            star, ebv[star]))
        model *= dust_transmission(stdwave, ebv[star])

        # Compute final color of dust-extincted model
        photsys = fibermap['PHOTSYS'][star]

        if not gaia_std:
            model_mag1, model_mag2 = [
                get_magnitude(stdwave, model, model_filters, _ + photsys)
                for _ in [color_band1, color_band2]
            ]
        else:
            model_mag1, model_mag2 = [
                get_magnitude(stdwave, model, model_filters, _)
                for _ in [color_band1, color_band2]
            ]

        if color_band1 == ref_mag_name:
            model_magr = model_mag1
        elif color_band2 == ref_mag_name:
            model_magr = model_mag2
        else:
            # if the reference magnitude is not among colours
            # I'm fetching it separately. This will happen when
            # colour is BP-RP and ref magnitude is G
            if gaia_std:
                model_magr = get_magnitude(stdwave, model, model_filters,
                                           ref_mag_name)
            else:
                model_magr = get_magnitude(stdwave, model, model_filters,
                                           ref_mag_name + photsys)
        fitted_model_colors[star] = model_mag1 - model_mag2

        #- TODO: move this back into normalize_templates, at the cost of
        #- recalculating a model magnitude?

        cur_refmag = star_mags[ref_mag_name][star]

        # Normalize the best model using reported magnitude
        scalefac = 10**((model_magr - cur_refmag) / 2.5)

        log.info('scaling {} mag {:.3f} to {:.3f} using scale {}'.format(
            ref_mag_name, model_magr, cur_refmag, scalefac))
        normflux[star] = model * scalefac

    if head_comm is not None and rank < nstars:  # head_comm color is 1
        linear_coefficients = head_comm.reduce(linear_coefficients,
                                               op=MPI.SUM,
                                               root=0)
        redshift = head_comm.reduce(redshift, op=MPI.SUM, root=0)
        chi2dof = head_comm.reduce(chi2dof, op=MPI.SUM, root=0)
        fitted_model_colors = head_comm.reduce(fitted_model_colors,
                                               op=MPI.SUM,
                                               root=0)
        normflux = head_comm.reduce(normflux, op=MPI.SUM, root=0)

    # Check at least one star was fit. The check is performed on rank 0 and
    # the result is bcast to other ranks so that all ranks exit together if
    # the check fails.
    atleastonestarfit = False
    if rank == 0:
        fitted_stars = np.where(chi2dof != 0)[0]
        atleastonestarfit = fitted_stars.size > 0
    if comm is not None:
        atleastonestarfit = comm.bcast(atleastonestarfit, root=0)
    if not atleastonestarfit:
        log.error("No star has been fit.")
        sys.exit(12)

    # Now write the normalized flux for all best models to a file
    if rank == 0:

        # get the fibermap from any input frame for the standard stars
        fibermap = Table(frame.fibermap)
        keep = np.isin(fibermap['FIBER'], starfibers[fitted_stars])
        fibermap = fibermap[keep]

        # drop fibermap columns specific to exposures instead of targets
        for col in [
                'DELTA_X', 'DELTA_Y', 'EXPTIME', 'NUM_ITER', 'FIBER_RA',
                'FIBER_DEC', 'FIBER_X', 'FIBER_Y'
        ]:
            if col in fibermap.colnames:
                fibermap.remove_column(col)

        data = {}
        data['LOGG'] = linear_coefficients[fitted_stars, :].dot(logg)
        data['TEFF'] = linear_coefficients[fitted_stars, :].dot(teff)
        data['FEH'] = linear_coefficients[fitted_stars, :].dot(feh)
        data['CHI2DOF'] = chi2dof[fitted_stars]
        data['REDSHIFT'] = redshift[fitted_stars]
        data['COEFF'] = linear_coefficients[fitted_stars, :]
        data['DATA_%s' % color] = star_colors[color][fitted_stars]
        data['MODEL_%s' % color] = fitted_model_colors[fitted_stars]
        data['BLUE_SNR'] = snr['b'][fitted_stars]
        data['RED_SNR'] = snr['r'][fitted_stars]
        data['NIR_SNR'] = snr['z'][fitted_stars]
        io.write_stdstar_models(args.outfile, normflux, stdwave,
                                starfibers[fitted_stars], data, fibermap,
                                input_frames_table)
Ejemplo n.º 7
0
def main(args):
    """Process one extracted frame and write the processed frame to disk.

    Depending on which inputs are supplied, the frame read from
    ``args.infile`` goes through, in this order: 1D cosmic-ray rejection,
    fiberflat correction, sky subtraction, fiber cross-talk correction,
    flux calibration and template-SNR (tSNR) scoring.  The result is
    written to ``args.outfile``.

    Args:
        args: parsed command-line options.  Attributes read here are
            ``infile``, ``outfile``, ``fiberflat``, ``sky``, ``calib``,
            ``cosmics_nsig``, ``no_tsnr``, ``alpha_only``, ``no_xtalk``,
            ``no_zero_ivar`` and ``no_sky_throughput_correction``.

    The process exits with status 12 (``sys.exit(12)``) when none of
    ``--fiberflat``, ``--sky``, ``--calib`` is given, or when the tSNR
    computation is requested without both ``--sky`` and ``--calib``.
    """
    log = get_logger()

    if (args.fiberflat is None) and (args.sky is None) and (args.calib is
                                                            None):
        log.critical('no --fiberflat, --sky, or --calib; nothing to do ?!?')
        sys.exit(12)

    compute_tsnr = not args.no_tsnr

    # The tSNR step below is handed both the sky model and the flux
    # calibration, so both inputs must be present.  Checking only --calib
    # would let a missing --sky through and fail much later, after all the
    # expensive processing, when the sky model is used.
    # NOTE(review): the message also lists --fiberflat; whether calc_tsnr2
    # tolerates fiberflat=None is not visible here, so it is not enforced.
    if compute_tsnr and (args.sky is None or args.calib is None):
        log.critical(
            'need --fiberflat --sky and --calib to compute template SNR')
        sys.exit(12)

    frame = read_frame(args.infile)

    uncalibrated_frame = None
    if compute_tsnr:
        # The tSNR alpha calculation requires the uncalibrated frame with
        # no sky subtraction applied, so keep a pristine copy.
        uncalibrated_frame = copy.deepcopy(frame)

    #- Raw scores already added in extraction, but just in case they weren't
    #- it is harmless to rerun to make sure we have them.
    compute_and_append_frame_scores(frame, suffix="RAW")

    # Without a sky model, reject cosmics right away; otherwise rejection
    # is done after sky subtraction (see below).
    if args.cosmics_nsig > 0 and args.sky is None:
        log.info("cosmics ray 1D rejection")
        reject_cosmic_rays_1d(frame, args.cosmics_nsig)

    fiberflat = None
    if args.fiberflat is not None:
        log.info("apply fiberflat")
        # read fiberflat
        fiberflat = read_fiberflat(args.fiberflat)

        # apply fiberflat to all fibers
        apply_fiberflat(frame, fiberflat)
        compute_and_append_frame_scores(frame, suffix="FFLAT")

    # When the cross-talk correction runs, zeroing of masked ivar is
    # deferred until after that correction; only when cross-talk is skipped
    # is sky subtraction asked to zero the ivar itself.
    zero_ivar = bool(args.no_xtalk) and not args.no_zero_ivar

    skymodel = None
    if args.sky is not None:

        # read sky
        skymodel = read_sky(args.sky)

        if args.cosmics_nsig > 0:

            # Work on a copy of the frame (not elegant but robust) so the
            # trial subtraction does not touch the real data.
            copied_frame = copy.deepcopy(frame)

            # first subtract sky without throughput correction
            subtract_sky(copied_frame,
                         skymodel,
                         apply_throughput_correction=False,
                         zero_ivar=zero_ivar)

            # then find cosmics
            log.info("cosmics ray 1D rejection after sky subtraction")
            reject_cosmic_rays_1d(copied_frame, args.cosmics_nsig)

            # carry the cosmic-ray mask over to the real frame
            frame.mask = copied_frame.mask

        # Subtract the sky from the real frame, with the requested
        # throughput-correction setting (same call whether or not the
        # cosmic-ray pass above ran).
        subtract_sky(frame,
                     skymodel,
                     apply_throughput_correction=(
                         not args.no_sky_throughput_correction),
                     zero_ivar=zero_ivar)

        compute_and_append_frame_scores(frame, suffix="SKYSUB")

    if not args.no_xtalk:
        log.info("fiber crosstalk correction")
        correct_fiber_crosstalk(frame, fiberflat)

        if not args.no_zero_ivar:
            # zero the inverse variance wherever any mask bit is set
            frame.ivar *= (frame.mask == 0)

    fluxcalib = None
    if args.calib is not None:
        log.info("calibrate")
        # read calibration
        fluxcalib = read_flux_calibration(args.calib)
        # apply calibration
        apply_flux_calibration(frame, fluxcalib)

        # Ensure that ivars are set to 0 for all values if any designated
        # fibermask bit is set. Also flips a bit for each frame.mask value
        # using specmask.BADFIBER
        frame = get_fiberbitmasked_frame(
            frame, bitmask="flux", ivar_framemask=(not args.no_zero_ivar))
        compute_and_append_frame_scores(frame, suffix="CALIB")

    if compute_tsnr:
        log.info("calculating tsnr")
        results, alpha = calc_tsnr2(uncalibrated_frame,
                                    fiberflat=fiberflat,
                                    skymodel=skymodel,
                                    fluxcalib=fluxcalib,
                                    alpha_only=args.alpha_only)

        frame.meta['TSNRALPH'] = alpha

        comments = {k: "from calc_frame_tsnr" for k in results.keys()}
        append_frame_scores(frame, results, comments, overwrite=True)

    # record inputs
    frame.meta['IN_FRAME'] = shorten_filename(args.infile)
    frame.meta['FIBERFLT'] = shorten_filename(args.fiberflat)
    frame.meta['IN_SKY'] = shorten_filename(args.sky)
    frame.meta['IN_CALIB'] = shorten_filename(args.calib)

    # save output
    write_frame(args.outfile, frame, units='10**-17 erg/(s cm2 Angstrom)')
    log.info("successfully wrote %s" % args.outfile)