Пример #1
0
def get_new_frame(night, expid, camera, basedir, nsky, rep=0):
    '''For a given frame file, write an updated copy to basedir with a chosen number of sky fibers.

    Args:
        night: YYYYMMDD (int or float)
        expid: exposure id without padding zeros (int or float; converted to
            int for the zero-padded output filename)
        camera: examples are b3, r4, z2, etc. (string)
        basedir: path to directory you want new frame to be written
        nsky: number of fibers flagged sky in new frame fibermap
    Options:
        rep: realization index used only to label the output filename
            frame-{camera}-{expid:08d}-{nsky}-{rep}.fits; default is 0
    Writes a new frame file with desispec.io.write_frame().'''
    import os

    framefile = desispec.io.findfile('frame', night, expid, camera=camera)
    print(framefile)
    header = fitsio.read_header(framefile)
    #- default calibration fiberflat for this frame's header
    fiberflatfile = findcalibfile([header,], 'FIBERFLAT')
    frame = desispec.io.read_frame(framefile)
    fiberflat = desispec.io.read_fiberflat(fiberflatfile)
    print(fiberflatfile)
    #- return value is ignored: pick_sky_fibers is expected to update frame
    #- in place (NOTE(review): confirm against its implementation)
    pick_sky_fibers(frame, fiberflat, nsky=nsky)

    #- output updated frame to basedir; '{:08d}' needs an int, so cast expid
    newframefile = os.path.join(
        basedir,
        'frame-{}-{:08d}-{}-{}.fits'.format(camera, int(expid), nsky, rep))
    desispec.io.write_frame(newframefile, frame)
Пример #2
0
def run_compute_sky(night, expid, cameras, basedir, nsky_list, reps=None):
    '''Generates sky models for new frame files, using --no-extra-variance option, which doesn't inflate the output errors for sky subtraction systematics.

    For every camera, realization index and nsky value, runs
    ``desi_compute_sky`` on basedir/frame-{cam}-{expid:08d}-{nsky}-{rep}.fits
    and writes basedir/sky-{cam}-{expid:08d}-{nsky}-{rep}.fits. Failures are
    reported with a printed FAILED and do not stop the loop.

    Args:
        night: YYYYMMDD (int or float)
        expid: exposure id without padding zeros (int or float)
        cameras: list of cameras corresponding to frame files in given directory, example ['r3', 'z3', 'b3'] (list or array)
        basedir: where to look for frame files
        nsky_list: list with different numbers of fibers frame files were generated with.
    Options:
        reps: number of different frame files for each camera and nsky combination. Default (None) means 5
    '''
    import os

    if reps is None:
        reps = 5

    #- '{:08d}' requires an int, even if expid was given as a float
    expid_str = '{:08d}'.format(int(expid))

    for cam in cameras:
        framefile = desispec.io.findfile('frame', night, expid, camera=cam)
        header = fitsio.read_header(framefile)
        fiberflatfile = findcalibfile([header,], 'FIBERFLAT')
        for N in range(reps):
            for n in nsky_list:
                suffix = '{}-{}-{}-{}.fits'.format(cam, expid_str, n, N)
                newframefile = os.path.join(basedir, 'frame-' + suffix)
                skyfile = os.path.join(basedir, 'sky-' + suffix)
                #- build argv directly so paths containing spaces survive
                cmd = ['desi_compute_sky', '-i', newframefile,
                       '--fiberflat', fiberflatfile, '-o', skyfile,
                       '--no-extra-variance']
                print('RUNNING {}'.format(' '.join(cmd)))
                err = subprocess.call(cmd)
                if err:
                    print('FAILED')
                else:
                    print('OK')
Пример #3
0
def correct_fiber_crosstalk(frame,fiberflat=None,xyset=None):
    """Subtract the estimated fiber cross-talk from a frame, in place.

    Contamination from neighboring fibers at offsets -2, -1, +1 and +2 is
    summed and subtracted from frame.flux. Its variance is added to the
    frame variance before frame.ivar is recomputed. Pixels whose ivar was 0
    keep ivar == 0.

    Args:
        frame : desispec.frame.Frame object (flux and ivar are modified)

    Optional:
        fiberflat : desispec.fiberflat.FiberFlat object, passed through to compute_contamination
        xyset : desispec.xytraceset.XYTraceSet object with trace coordinates to shift the spectra
                (if None, read from the PSF found with the calibration finder)
    """
    log = get_logger()

    crosstalk_params = read_crosstalk_parameters()

    if xyset is None:
        # no trace set supplied: use the PSF chosen by the calibration finder
        xyset = read_xytraceset(findcalibfile([frame.meta,], "PSF"))

    log.info("compute kernels")
    kernels = compute_crosstalk_kernels()

    total_contamination = np.zeros(frame.flux.shape)
    total_variance = np.zeros(frame.flux.shape)

    for offset in (-2, -1, 1, 2):
        log.info("F{:+d}".format(offset))
        # kernels are indexed by the absolute fiber distance
        offset_kernel = kernels[np.abs(offset)]
        fiber_cont, fiber_var = compute_contamination(
            frame, offset, offset_kernel, crosstalk_params, xyset, fiberflat)
        total_contamination += fiber_cont
        total_variance += fiber_var

    frame.flux -= total_contamination

    # Convert ivar to variance, adding 1 where ivar == 0 to avoid division
    # by zero; the (ivar > 0) numerator keeps those pixels at ivar == 0.
    variance = 1./(frame.ivar + (frame.ivar == 0))
    valid = frame.ivar > 0
    frame.ivar = valid/(variance + total_variance)
Пример #4
0
def run_sky_subtraction(night, expid, cameras, basedir, nsky_list, reps=None):
    '''Runs sky subtraction with new sky models and frame files.

    For every camera, realization index and nsky value, reads
    basedir/frame-{cam}-{expid:08d}-{nsky}-{rep}.fits and the matching
    basedir/sky-... file. It then applies the camera's default calibration
    fiberflat, subtracts the sky model, and writes
    basedir/sframe-{cam}-{expid:08d}-{nsky}-{rep}.fits.

    Args:
        night: YYYYMMDD (int or float)
        expid: exposure id without padding zeros (int or float)
        cameras: list of cameras corresponding to frame and sky files in given directory, example ['r3', 'z3', 'b3'] (list or array)
        basedir: where to look for frame files
        nsky_list: list with different numbers of fibers frame files and sky files were generated with.
    Options:
        reps: number of different frame/sky files for each camera and nsky combination. Default (None) means 5
    '''
    import os

    if reps is None:
        reps = 5

    #- '{:08d}' requires an int, even if expid was given as a float
    expid_str = '{:08d}'.format(int(expid))

    for cam in cameras:
        framefile = desispec.io.findfile('frame', night, expid, camera=cam)
        header = fitsio.read_header(framefile)
        fiberflatfile = findcalibfile([header,], 'FIBERFLAT')
        fiberflat = desispec.io.read_fiberflat(fiberflatfile)
        for N in range(reps):
            for n in nsky_list:
                suffix = '{}-{}-{}-{}.fits'.format(cam, expid_str, n, N)
                newframefile = os.path.join(basedir, 'frame-' + suffix)
                skyfile = os.path.join(basedir, 'sky-' + suffix)
                sframe = desispec.io.read_frame(newframefile)
                sky = desispec.io.read_sky(skyfile)
                #- return values are not used; sframe is written afterwards,
                #- so these are expected to modify sframe in place
                apply_fiberflat(sframe, fiberflat)
                subtract_sky(sframe, sky)

                sframefile = os.path.join(basedir, 'sframe-' + suffix)
                desispec.io.write_frame(sframefile, sframe)
Пример #5
0
def calc_tsnr2_cframe(cframe):
    """
    Given cframe, calc_tsnr2 guessing frame,fiberflat,skymodel,fluxcalib to use

    The frame, sky and fluxcalib files are taken from the cframe's directory.
    Their names are the cframe filename with 'cframe' replaced by 'frame',
    'sky' and 'fluxcalib'. The nightly fiberflat is used if it exists;
    otherwise the default calibration fiberflat is found with findcalibfile.

    Args:
        cframe: input cframe Frame object (uses .filename and meta NIGHT, CAMERA)

    Returns (results, alpha) from calc_tsnr2

    Raises:
        ValueError: if the frame, sky or fluxcalib file is missing
    """
    log = get_logger()
    dirname, filename = os.path.split(cframe.filename)
    framefile = os.path.join(dirname, filename.replace('cframe', 'frame'))
    skyfile = os.path.join(dirname, filename.replace('cframe', 'sky'))
    fluxcalibfile = os.path.join(dirname, filename.replace('cframe', 'fluxcalib'))

    for testfile in (framefile, skyfile, fluxcalibfile):
        if not os.path.exists(testfile):
            #- f-string so the message actually names the missing file
            msg = f'missing {testfile}; unable to calculate TSNR2'
            log.error(msg)
            raise ValueError(msg)

    night = cframe.meta['NIGHT']
    camera = cframe.meta['CAMERA']
    fiberflatfile = desispec.io.findfile('fiberflatnight', night, camera=camera)
    if not os.path.exists(fiberflatfile):
        ffname = os.path.basename(fiberflatfile)
        log.warning(f'{ffname} not found; using default calibs')
        fiberflatfile = findcalibfile([cframe.meta,], 'FIBERFLAT')

    frame = desispec.io.read_frame(framefile)
    fiberflat = desispec.io.read_fiberflat(fiberflatfile)
    skymodel = desispec.io.read_sky(skyfile)
    fluxcalib = desispec.io.read_flux_calibration(fluxcalibfile)

    return calc_tsnr2(frame, fiberflat, skymodel, fluxcalib)
Пример #6
0
def main(args, comm=None):
    """Run specex PSF fitting bundle-by-bundle (optionally over MPI) and merge the results.

    Each bundle of fibers is fit by calling ``libspecex.cspecex_desi_psf_fit``
    through ctypes with a desi_psf_fit command line. Rank 0 then merges the
    per-bundle PSF files into ``args.output_psf`` (unless disable_merge).

    Args:
        args: parsed arguments; this function reads input_image, output_psf,
            input_psf, extra, specmin, nspec, bundlesize, broken_fibers,
            debug and disable_merge.
        comm: optional MPI communicator; bundles are split across its ranks.

    Raises:
        RuntimeError: if output_psf has no .fits extension, if any bundle fit
            failed, or if merging the per-bundle files failed. With MPI, all
            ranks raise together.
    """
    import traceback

    log = get_logger()

    imgfile = args.input_image
    outfile = args.output_psf

    if args.input_psf is not None:
        inpsffile = args.input_psf
    else:
        from desispec.calibfinder import findcalibfile
        hdr = fits.getheader(imgfile)
        inpsffile = findcalibfile([hdr,], 'PSF')

    optarray = []
    if args.extra is not None:
        optarray = args.extra.split()

    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)

    specmax = specmin + nspec

    # Divide spectra into bundles: bundle b nominally covers fibers
    # [b*bundlesize, (b+1)*bundlesize), clipped to [specmin, specmax).
    bundles = sorted(set((np.arange(specmin, specmax) // bundlesize).tolist()))
    nbundle = len(bundles)

    bspecmin = {}
    bnspec = {}
    for b in bundles:
        bspecmin[b] = max(specmin, b * bundlesize)
        bnspec[b] = min((b + 1) * bundlesize, specmax) - bspecmin[b]

    # Assign a contiguous range of bundles to each process; the first
    # `leftover` ranks get one extra bundle.

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    # bundle ids start at specmin // bundlesize, not necessarily at 0
    firstbundle = bundles[0] if nbundle > 0 else 0
    mynbundle = nbundle // nproc
    leftover = nbundle % nproc
    if rank < leftover:
        mynbundle += 1
        myfirstbundle = firstbundle + rank * mynbundle
    else:
        myfirstbundle = firstbundle + ((mynbundle + 1) * leftover) + \
            (mynbundle * (rank - leftover))

    if rank == 0:
        # Print parameters
        log.info("specex:  using {} processes".format(nproc))
        log.info("specex:  input image = {}".format(imgfile))
        log.info("specex:  input PSF = {}".format(inpsffile))
        log.info("specex:  output = {}".format(outfile))
        log.info("specex:  bundlesize = {}".format(bundlesize))
        log.info("specex:  specmin = {}".format(specmin))
        log.info("specex:  specmax = {}".format(specmax))
        if args.broken_fibers:
            log.info("specex:  broken fibers = {}".format(args.broken_fibers))

    # get the root output file

    outpat = re.compile(r'(.*)\.fits')
    outmat = outpat.match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)

    # an output in the current directory has outdir == "", which makedirs rejects
    outdir = os.path.dirname(outroot)
    if rank == 0 and outdir != "":
        os.makedirs(outdir, exist_ok=True)
    if comm is not None:
        # all ranks write per-bundle files into outdir; wait for rank 0 to create it
        comm.barrier()

    failcount = 0

    for b in range(myfirstbundle, myfirstbundle + mynbundle):
        outbundle = "{}_{:02d}".format(outroot, b)
        outbundlefits = "{}.fits".format(outbundle)
        com = ['desi_psf_fit']
        com.extend(['-a', imgfile])
        com.extend(['--in-psf', inpsffile])
        com.extend(['--out-psf', outbundlefits])
        com.extend(['--first-bundle', "{}".format(b)])
        com.extend(['--last-bundle', "{}".format(b)])
        com.extend(['--first-fiber', "{}".format(bspecmin[b])])
        com.extend(['--last-fiber', "{}".format(bspecmin[b] + bnspec[b] - 1)])
        if args.broken_fibers:
            com.extend(['--broken-fibers', "{}".format(args.broken_fibers)])
        if args.debug:
            com.extend(['--debug'])

        com.extend(optarray)

        log.debug("proc {} calling {}".format(rank, " ".join(com)))

        # Build a C argv (char**) from the command list for the library call;
        # arg_buffers must stay alive until the call returns.
        argc = len(com)
        arg_buffers = [ct.create_string_buffer(com[i].encode('ascii')) \
            for i in range(argc)]
        addrlist = [ ct.cast(x, ct.POINTER(ct.c_char)) for x in \
            map(ct.addressof, arg_buffers) ]
        arg_pointers = (ct.POINTER(ct.c_char) * argc)(*addrlist)

        retval = libspecex.cspecex_desi_psf_fit(argc, arg_pointers)

        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                      "value {} running {}".format(rank, retval, comstr))
            failcount += 1

    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")

    if rank == 0:
        outfits = "{}.fits".format(outroot)

        inputs = ["{}_{:02d}.fits".format(outroot, x) for x in bundles]

        if args.disable_merge:
            log.info("don't merge")
        else:
            #- Empirically it appears that files written by one rank sometimes
            #- aren't fully buffer-flushed and closed before getting here,
            #- despite the MPI allreduce barrier.  Pause to let I/O catch up.
            log.info('HACK: taking a 20 sec pause before merging')
            sys.stdout.flush()
            time.sleep(20.)

            #- catch merge failures so rank 0 still reaches the bcast below;
            #- otherwise the other ranks would hang waiting for it
            try:
                merge_psf(inputs, outfits)
            except Exception:
                log.error('merge_psf failed for {}'.format(outfits))
                log.error(traceback.format_exc())
                failcount += 1
            else:
                log.info('done merging')
                # only remove the per-bundle files if the merge was good
                for f in inputs:
                    if os.path.isfile(f):
                        os.remove(f)

    if comm is not None:
        failcount = comm.bcast(failcount, root=0)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("merging of per-bundle files failed")

    return
Пример #7
0
def calc_tsnr2(frame, fiberflat, skymodel, fluxcalib, alpha_only=False) :
    '''
    Compute template SNR^2 values for a given frame

    Args:
        frame : uncalibrated Frame object for one camera
        fiberflat : FiberFlat object
        skymodel : SkyModel object
        fluxcalib : FluxCalib object

    Options:
        alpha_only : if True, return ({}, alpha) right after computing alpha,
            skipping the per-tracer TSNR^2 computation

    returns (tsnr2, alpha):
        `tsnr2` dictionary, with keys labeling tracer (bgs,elg,etc.), of values
        holding nfiber length array of the tsnr^2 values for this camera, and
        `alpha`, the relative weighting btwn rdnoise & sky terms to model var.
        Keys have the form TSNR2_{TRACER}_{BAND}.

    Raises:
        RuntimeError: if frame BUNIT is not count/Angstrom or electron/Angstrom,
            or if $DESIMODEL is not set.

    Note:  Assumes DESIMODEL is set and up to date. The NEA/angperpix
    interpolators (per camera) and TSNR ensembles (per band) are cached in
    module-level globals and reused across calls.
    '''
    global _camera_nea_angperpix
    global _band_ensemble

    log=get_logger()

    if not (frame.meta["BUNIT"]=="count/Angstrom" or frame.meta["BUNIT"]=="electron/Angstrom" ) :
        log.error("requires an uncalibrated frame")
        raise RuntimeError("requires an uncalibrated frame")

    camera=frame.meta["CAMERA"].strip().lower()
    band=camera[0]

    psfpath=findcalibfile([frame.meta],"PSF")
    psf=GaussHermitePSF(psfpath)

    # Returns bivariate spline to be evaluated at (fiber, wave).
    if not "DESIMODEL" in os.environ :
        msg = "requires $DESIMODEL to get the NEA and the SNR templates"
        log.error(msg)
        raise RuntimeError(msg)

    if _camera_nea_angperpix is None:
        _camera_nea_angperpix = dict()

    # nea and angperpix_interp are callables evaluated at (fibers, wave) below
    if camera in _camera_nea_angperpix:
        nea, angperpix_interp = _camera_nea_angperpix[camera]
    else:
        neafilename=os.path.join(os.environ["DESIMODEL"],
                                 f"data/specpsf/nea/masternea_{camera}.fits")
        log.info("read NEA file {}".format(neafilename))
        nea, angperpix_interp = read_nea(neafilename)
        _camera_nea_angperpix[camera] = nea, angperpix_interp

    if _band_ensemble is None:
        _band_ensemble = dict()

    if band in _band_ensemble:
        ensemble = _band_ensemble[band]
    else:
        ensembledir=os.path.join(os.environ["DESIMODEL"],"data/tsnr")
        log.info("read TSNR ensemble files in {}".format(ensembledir))
        ensemble = get_ensemble(ensembledir, bands=[band,])
        _band_ensemble[band] = ensemble

    nspec, nwave = fluxcalib.calib.shape

    fibers = np.arange(nspec)
    rdnoise = fb_rdnoise(fibers, frame, psf)

    #
    ebv = frame.fibermap['EBV']

    if np.sum(ebv!=0)>0 :
        log.info("TSNR MEDIAN EBV = {:.3f}".format(np.median(ebv[ebv!=0])))
    else :
        log.info("TSNR MEDIAN EBV = 0")

    # Evaluate.
    npix = nea(fibers, frame.wave)
    angperpix = angperpix_interp(fibers, frame.wave)
    angperspecbin = np.mean(np.gradient(frame.wave))

    for label, x in zip(['RDNOISE', 'NEA', 'ANGPERPIX', 'ANGPERSPECBIN'], [rdnoise, npix, angperpix, angperspecbin]):
        log.info('{} \t {:.3f} +- {:.3f}'.format(label.ljust(10), np.median(x), np.std(x)))

    # Relative weighting between rdnoise & sky terms to model var.
    alpha = calc_alpha(frame, fibermap=frame.fibermap,
                rdnoise_sigma=rdnoise, npix_1d=npix,
                angperpix=angperpix, angperspecbin=angperspecbin,
                fiberflat=fiberflat, skymodel=skymodel)

    log.info(f"TSNR ALPHA = {alpha:.6f}")

    if alpha_only:
        return {}, alpha

    # 1 for usable pixels, 0 where the frame mask is set or ivar <= 0.
    # (np.float was removed in NumPy 1.24; the builtin float is equivalent.)
    maskfactor = np.ones_like(frame.mask, dtype=float)
    maskfactor[frame.mask > 0] = 0.0
    maskfactor *= (frame.ivar > 0.0)

    tsnrs = {}

    denom = var_model(rdnoise, npix, angperpix, angperspecbin, fiberflat, skymodel, alpha=alpha)

    for tracer in ensemble.keys():
        wave = ensemble[tracer].wave[band]
        dflux = ensemble[tracer].flux[band]

        if len(frame.wave) != len(wave) or not np.allclose(frame.wave, wave):
            log.warning(f'Resampling {tracer} ensemble wavelength to match input {camera} frame')
            tmp = np.zeros([dflux.shape[0], len(frame.wave)])
            for i in range(dflux.shape[0]):
                tmp[i] = np.interp(frame.wave, wave, dflux[i],
                            left=dflux[i,0], right=dflux[i,-1])
            dflux = tmp
            wave = frame.wave

        # Work in uncalibrated flux units (electrons per angstrom); flux_calib includes exptime. tau.
        # Broadcast.
        dflux = dflux * fluxcalib.calib # [e/A]

        # Wavelength dependent fiber flat;  Multiply or divide - check with Julien.
        result = dflux * fiberflat.fiberflat

        # Apply dust transmission.
        result *= dust_transmission(frame.wave, ebv)

        result = result**2.

        result /= denom

        # Eqn. (1) of https://desi.lbl.gov/DocDB/cgi-bin/private/RetrieveFile?docid=4723;filename=sky-monitor-mc-study-v1.pdf;version=2
        tsnrs[tracer] = np.sum(result * maskfactor, axis=1)

    results=dict()
    for tracer in tsnrs.keys():
        key = 'TSNR2_{}_{}'.format(tracer.upper(), band.upper())
        results[key]=tsnrs[tracer]
        log.info('{} = {:.6f}'.format(key, np.median(tsnrs[tracer])))

    return results, alpha
Пример #8
0
def main(args, comm=None):
    """Run specex PSF fitting bundle-by-bundle (optionally over MPI) and merge the results.

    Each bundle of fibers is fit by ``specex.specex.run_specex`` with a
    desi_psf_fit command line. z-band cameras use a degree-3 Legendre
    wavelength fit plus --fit-continuum; other bands use degree 1. Rank 0
    then merges the per-bundle PSF files into ``args.output_psf`` (unless
    disable_merge).

    Args:
        args: parsed arguments; this function reads input_image, output_psf,
            input_psf, extra, specmin, nspec, bundlesize, broken_fibers,
            debug and disable_merge.
        comm: optional MPI communicator; bundles are split across its ranks.

    Raises:
        RuntimeError: if output_psf has no .fits extension, if any bundle fit
            failed, or if merging the per-bundle files failed. With MPI, all
            ranks raise together.
    """
    import traceback

    log = get_logger()

    #- only import when running, to avoid requiring specex install for import
    from specex.specex import run_specex

    imgfile = args.input_image
    outfile = args.output_psf

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    #- read the image header once on rank 0 and share it with the other ranks
    hdr = None
    if rank == 0:
        hdr = fits.getheader(imgfile)
    if comm is not None:
        hdr = comm.bcast(hdr, root=0)

    #- Locate line list in $SPECEXDATA or specex/data
    if 'SPECEXDATA' in os.environ:
        specexdata = os.environ['SPECEXDATA']
    else:
        from pkg_resources import resource_filename
        specexdata = resource_filename('specex', 'data')

    lamp_lines_file = os.path.join(specexdata, 'specex_linelist_desi.txt')

    if args.input_psf is not None:
        inpsffile = args.input_psf
    else:
        from desispec.calibfinder import findcalibfile
        inpsffile = findcalibfile([hdr,], 'PSF')

    optarray = []
    if args.extra is not None:
        optarray = args.extra.split()

    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)

    specmax = specmin + nspec

    # Divide spectra into bundles: bundle b nominally covers fibers
    # [b*bundlesize, (b+1)*bundlesize), clipped to [specmin, specmax).
    bundles = sorted(set((np.arange(specmin, specmax) // bundlesize).tolist()))
    nbundle = len(bundles)

    bspecmin = {}
    bnspec = {}
    for b in bundles:
        bspecmin[b] = max(specmin, b * bundlesize)
        bnspec[b] = min((b + 1) * bundlesize, specmax) - bspecmin[b]

    # Assign a contiguous range of bundles to each process; the first
    # `leftover` ranks get one extra bundle. Bundle ids start at
    # bundles[0] (guarded so nspec == 0 does not raise IndexError).
    firstbundle = bundles[0] if nbundle > 0 else 0
    mynbundle = nbundle // nproc
    leftover = nbundle % nproc
    if rank < leftover:
        mynbundle += 1
        myfirstbundle = firstbundle + rank * mynbundle
    else:
        myfirstbundle = firstbundle + ((mynbundle + 1) * leftover) + \
            (mynbundle * (rank - leftover))

    if rank == 0:
        # Print parameters
        log.info("specex:  using {} processes".format(nproc))
        log.info("specex:  input image = {}".format(imgfile))
        log.info("specex:  input PSF = {}".format(inpsffile))
        log.info("specex:  output = {}".format(outfile))
        log.info("specex:  bundlesize = {}".format(bundlesize))
        log.info("specex:  specmin = {}".format(specmin))
        log.info("specex:  specmax = {}".format(specmax))
        if args.broken_fibers:
            log.info("specex:  broken fibers = {}".format(args.broken_fibers))

    # get the root output file

    outpat = re.compile(r'(.*)\.fits')
    outmat = outpat.match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)

    outdir = os.path.dirname(outroot)
    if rank == 0 and outdir != "":
        os.makedirs(outdir, exist_ok=True)
    if comm is not None:
        # all ranks write per-bundle files into outdir; wait for rank 0 to create it
        comm.barrier()

    cam = hdr["camera"].lower().strip()
    band = cam[0]

    failcount = 0

    for b in range(myfirstbundle, myfirstbundle + mynbundle):
        outbundle = "{}_{:02d}".format(outroot, b)
        outbundlefits = "{}.fits".format(outbundle)
        com = ['desi_psf_fit']
        com.extend(['-a', imgfile])
        com.extend(['--in-psf', inpsffile])
        com.extend(['--out-psf', outbundlefits])
        com.extend(['--lamp-lines', lamp_lines_file])
        com.extend(['--first-bundle', "{}".format(b)])
        com.extend(['--last-bundle', "{}".format(b)])
        com.extend(['--first-fiber', "{}".format(bspecmin[b])])
        com.extend(['--last-fiber', "{}".format(bspecmin[b] + bnspec[b] - 1)])
        if band == "z":
            com.extend(['--legendre-deg-wave', "{}".format(3)])
            com.extend(['--fit-continuum'])
        else:
            com.extend(['--legendre-deg-wave', "{}".format(1)])
        if args.broken_fibers:
            com.extend(['--broken-fibers', "{}".format(args.broken_fibers)])
        if args.debug:
            com.extend(['--debug'])

        com.extend(optarray)

        log.debug("proc {} calling {}".format(rank, " ".join(com)))

        retval = run_specex(com)

        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                      "value {} running {}".format(rank, retval, comstr))
            failcount += 1

    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")

    if rank == 0:
        outfits = "{}.fits".format(outroot)

        inputs = ["{}_{:02d}.fits".format(outroot, x) for x in bundles]

        if args.disable_merge:
            log.info("don't merge")
        else:
            #- Empirically it appears that files written by one rank sometimes
            #- aren't fully buffer-flushed and closed before getting here,
            #- despite the MPI allreduce barrier.  Pause to let I/O catch up.
            log.info('5 sec pause before merging')
            sys.stdout.flush()
            time.sleep(5.)

            #- catch merge failures so rank 0 still reaches the bcast below;
            #- otherwise the other ranks would hang waiting for it
            try:
                merge_psf(inputs, outfits)
            except Exception:
                log.error('merge_psf failed for {}'.format(outfits))
                log.error(traceback.format_exc())
                failcount += 1
            else:
                log.info('done merging')
                # only remove the per-bundle files if the merge was good
                for f in inputs:
                    if os.path.isfile(f):
                        os.remove(f)

    if comm is not None:
        failcount = comm.bcast(failcount, root=0)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("merging of per-bundle files failed")

    return
Пример #9
0
def main(args=None, comm=None):
    if args is None:
        args = parse()
    elif isinstance(args, (list, tuple)):
        args = parse(args)

    log = get_logger()

    start_mpi_connect = time.time()
    if comm is not None:
        #- Use the provided comm to determine rank and size
        rank = comm.rank
        size = comm.size
    else:
        #- Check MPI flags and determine the comm, rank, and size given the arguments
        comm, rank, size = assign_mpi(do_mpi=args.mpi,
                                      do_batch=args.batch,
                                      log=log)
    stop_mpi_connect = time.time()

    #- Start timer; only print log messages from rank 0 (others are silent)
    timer = desiutil.timer.Timer(silent=(rank > 0))

    #- Fill in timing information for steps before we had the timer created
    if args.starttime is not None:
        timer.start('startup', starttime=args.starttime)
        timer.stop('startup', stoptime=start_imports)

    timer.start('imports', starttime=start_imports)
    timer.stop('imports', stoptime=stop_imports)

    timer.start('mpi_connect', starttime=start_mpi_connect)
    timer.stop('mpi_connect', stoptime=stop_mpi_connect)

    #- Freeze IERS after parsing args so that it doesn't bother if only --help
    timer.start('freeze_iers')
    desiutil.iers.freeze_iers()
    timer.stop('freeze_iers')

    #- Preflight checks
    timer.start('preflight')

    # - Preflight checks
    if rank > 0:
        # - Let rank 0 fetch these, and then broadcast
        args, hdr, camhdr = None, None, None
    else:
        if args.inputs is None:
            if args.night is None or args.expids is None:
                raise RuntimeError(
                    'Must specify --inputs or --night AND --expids')
            else:
                args.expids = np.array(
                    args.expids.strip(' \t').split(',')).astype(int)
                args.inputs = []
                for expid in args.expids:
                    infile = findfile('raw', night=args.night, expid=expid)
                    args.inputs.append(infile)
                    if not os.path.isfile(infile):
                        raise IOError('Missing input file: {}'.format(infile))
        else:
            args.inputs = np.array(args.inputs.strip(' \t').split(','))
            #- args.night will be defined in update_args_with_headers,
            #- but let's define the expids here
            #- NOTE: inputs has priority. Overwriting expids if they existed.
            args.expids = []
            for infile in args.inputs:
                hdr = load_raw_data_header(pathname=infile,
                                           return_filehandle=False)
                args.expids.append(int(hdr['EXPID']))

        args.expids = np.sort(args.expids)
        args.inputs = np.sort(args.inputs)
        args.expid = args.expids[0]
        args.input = args.inputs[0]

        #- Use header information to fill in missing information in the arguments object
        args, hdr, camhdr = update_args_with_headers(args)

        #- If not a science observation, we don't need the hdr or camhdr objects,
        #- So let's not broadcast them to all the ranks
        if args.obstype != 'SCIENCE':
            hdr, camhdr = None, None

    if comm is not None:
        args = comm.bcast(args, root=0)
        hdr = comm.bcast(hdr, root=0)
        camhdr = comm.bcast(camhdr, root=0)

    known_obstype = ['SCIENCE', 'ARC', 'FLAT']
    if args.obstype not in known_obstype:
        raise RuntimeError('obstype {} not in {}'.format(
            args.obstype, known_obstype))

    timer.stop('preflight')

    # -------------------------------------------------------------------------
    # - Create and submit a batch job if requested

    if args.batch:
        #camword = create_camword(args.cameras)
        #exp_str = '-'.join('{:08d}'.format(expid) for expid in args.expids)
        if args.obstype.lower() == 'science':
            jobdesc = 'stdstarfit'
        elif args.obstype.lower() == 'arc':
            jobdesc = 'psfnight'
        elif args.obstype.lower() == 'flat':
            jobdesc = 'nightlyflat'
        else:
            jobdesc = args.obstype.lower()
        scriptfile = create_desi_proc_batch_script(night=args.night, exp=args.expids, cameras=args.cameras,\
                                                jobdesc=jobdesc, queue=args.queue, runtime=args.runtime,\
                                                batch_opts=args.batch_opts, timingfile=args.timingfile,
                                                system_name=args.system_name)
        err = 0
        if not args.nosubmit:
            err = subprocess.call(['sbatch', scriptfile])
        sys.exit(err)

    # -------------------------------------------------------------------------
    # - Proceed with running

    # - What are we going to do?
    if rank == 0:
        log.info('----------')
        log.info('Input {}'.format(args.inputs))
        log.info('Night {} expids {}'.format(args.night, args.expids))
        log.info('Obstype {}'.format(args.obstype))
        log.info('Cameras {}'.format(args.cameras))
        log.info('Output root {}'.format(desispec.io.specprod_root()))
        log.info('----------')

    # - Wait for rank 0 to make directories before proceeding
    if comm is not None:
        comm.barrier()

    # -------------------------------------------------------------------------
    # - Merge PSF of night if applicable

    if args.obstype in ['ARC']:
        timer.start('psfnight')
        num_cmd = num_err = 0
        if rank == 0:
            for camera in args.cameras:
                psfnightfile = findfile('psfnight', args.night, args.expids[0],
                                        camera)
                if not os.path.isfile(
                        psfnightfile
                ):  # we still don't have a psf night, see if we can compute it ...
                    psfs = [
                        findfile('psf', args.night, expid,
                                 camera).replace("psf", "fit-psf")
                        for expid in args.expids
                    ]
                    log.info(
                        "Number of PSF for night={} camera={} = {}".format(
                            args.night, camera, len(psfs)))
                    if len(psfs) > 4:  # lets do it!
                        log.info("Computing psfnight ...")
                        dirname = os.path.dirname(psfnightfile)
                        if not os.path.isdir(dirname):
                            os.makedirs(dirname)
                        num_cmd += 1

                        #- generic try/except so that any failure doesn't leave
                        #- MPI rank 0 hanging while others are waiting for it
                        try:
                            desispec.scripts.specex.mean_psf(
                                psfs, psfnightfile)
                        except:
                            log.error('specex.meanpsf failed for {}'.format(
                                os.path.basename(psfnightfile)))
                            exc_type, exc_value, exc_traceback = sys.exc_info()
                            lines = traceback.format_exception(
                                exc_type, exc_value, exc_traceback)
                            log.error(''.join(lines))
                            sys.stdout.flush()

                        if not os.path.exists(psfnightfile):
                            log.error(f'Failed to create {psfnightfile}')
                            num_err += 1
                    else:
                        log.info(
                            "Fewer than 4 psfs were provided, can't compute psfnight. Exiting ..."
                        )
                        num_cmd += 1
                        num_err += 1

        timer.stop('psfnight')

        num_cmd, num_err = mpi_count_failures(num_cmd, num_err, comm=comm)
        if rank == 0:
            if num_err > 0:
                log.error(f'{num_err}/{num_cmd} psfnight commands failed')

        if num_err > 0 and num_err == num_cmd:
            sys.stdout.flush()
            if rank == 0:
                log.critical('All psfnight commands failed')
            sys.exit(1)

    # -------------------------------------------------------------------------
    # - Average and auto-calib fiberflats of night if applicable

    if args.obstype in ['FLAT']:
        timer.start('fiberflatnight')
        #- Track number of commands run and number of errors for exit code
        num_cmd = 0
        num_err = 0
        if rank == 0:
            fiberflatnightfile = findfile('fiberflatnight', args.night,
                                          args.expids[0], args.cameras[0])
            fiberflatdirname = os.path.dirname(fiberflatnightfile)
            if os.path.isfile(fiberflatnightfile):
                log.info("Fiberflatnight already exists. Exitting ...")
            elif len(
                    args.cameras
            ) < 6:  # we still don't have them, see if we can compute them
                # , but need at least 2 spectros ...
                log.info(
                    "Fewer than 6 cameras were available, so couldn't perform joint fit. Exiting ..."
                )
            else:
                flats = []
                for camera in args.cameras:
                    for expid in args.expids:
                        flats.append(
                            findfile('fiberflat', args.night, expid, camera))
                log.info("Number of fiberflat for night {} = {}".format(
                    args.night, len(flats)))
                if len(flats) < 3 * 4 * len(args.cameras):
                    log.info(
                        "Fewer than 3 exposures with 4 lamps were available. Can't perform joint fit. Exiting..."
                    )
                else:
                    log.info(
                        "Computing fiberflatnight per lamp and camera ...")
                    tmpdir = os.path.join(fiberflatdirname, "tmp")
                    if not os.path.isdir(tmpdir):
                        os.makedirs(tmpdir)

                    log.info(
                        "First average measurements per camera and per lamp")
                    average_flats = dict()
                    for camera in args.cameras:
                        # list of flats for this camera
                        flats_for_this_camera = []
                        for flat in flats:
                            if flat.find(camera) >= 0:
                                flats_for_this_camera.append(flat)
                        # log.info("For camera {} , flats = {}".format(camera,flats_for_this_camera))
                        # sys.exit(12)

                        # average per lamp (and camera)
                        average_flats[camera] = list()
                        for lampbox in range(4):
                            ofile = os.path.join(
                                tmpdir,
                                "fiberflatnight-camera-{}-lamp-{}.fits".format(
                                    camera, lampbox))
                            if not os.path.isfile(ofile):
                                log.info(
                                    "Average flat for camera {} and lamp box #{}"
                                    .format(camera, lampbox))
                                pg = "CALIB DESI-CALIB-0{} LEDs only".format(
                                    lampbox)

                                cmd = "desi_average_fiberflat --program '{}' --outfile {} -i ".format(
                                    pg, ofile)
                                for flat in flats_for_this_camera:
                                    cmd += " {} ".format(flat)
                                num_cmd += 1
                                err = runcmd(cmd,
                                             inputs=flats_for_this_camera,
                                             outputs=[
                                                 ofile,
                                             ])
                                if err:
                                    num_err += 1
                                if os.path.isfile(ofile):
                                    average_flats[camera].append(ofile)
                                else:
                                    log.error(
                                        f"Generating {ofile} failed; proceeding with other flats"
                                    )
                            else:
                                log.info("Will use existing {}".format(ofile))
                                average_flats[camera].append(ofile)

                    log.info(
                        "Auto-calibration across lamps and spectro  per camera arm (b,r,z)"
                    )
                    for camera_arm in ["b", "r", "z"]:
                        cameras_for_this_arm = []
                        flats_for_this_arm = []
                        for camera in args.cameras:
                            if camera[0].lower() == camera_arm:
                                cameras_for_this_arm.append(camera)
                                if camera in average_flats:
                                    for flat in average_flats[camera]:
                                        flats_for_this_arm.append(flat)
                        if len(flats_for_this_arm) > 0:
                            cmd = "desi_autocalib_fiberflat --night {} --arm {} -i ".format(
                                args.night, camera_arm)
                            for flat in flats_for_this_arm:
                                cmd += " {} ".format(flat)
                            num_cmd += 1
                            err = runcmd(cmd,
                                         inputs=flats_for_this_arm,
                                         outputs=[])
                            if err:
                                num_err += 1
                        else:
                            log.error(f'No flats found for arm {camera_arm}')

                    log.info("Done with fiber flats per night")

        timer.stop('fiberflatnight')
        num_cmd, num_err = mpi_count_failures(num_cmd, num_err, comm=comm)
        if comm is not None:
            comm.barrier()

        if rank == 0:
            if num_err > 0:
                log.error(f'{num_err}/{num_cmd} fiberflat commands failed')

        if num_err > 0 and num_err == num_cmd:
            if rank == 0:
                log.critical('All fiberflat commands failed')
            sys.exit(1)

    ##################### Note #############################
    ### Still for single exposure. Needs to be re-factored #
    ########################################################

    if args.obstype in ['SCIENCE']:
        #inputfile = findfile('raw', night=args.night, expid=args.expids[0])
        #if not os.path.isfile(inputfile):
        #    raise IOError('Missing input file: {}'.format(inputfile))
        ## - Fill in values from raw data header if not overridden by command line
        #fx = fitsio.FITS(inputfile)
        #if 'SPEC' in fx:  # - 20200225 onwards
        #    # hdr = fits.getheader(args.input, 'SPEC')
        #    hdr = fx['SPEC'].read_header()
        #elif 'SPS' in fx:  # - 20200224 and before
        #    # hdr = fits.getheader(args.input, 'SPS')
        #    hdr = fx['SPS'].read_header()
        #else:
        #    # hdr = fits.getheader(args.input, 0)
        #    hdr = fx[0].read_header()
        #
        #camhdr = dict()
        #for cam in args.cameras:
        #    camhdr[cam] = fx[cam].read_header()

        #fx.close()

        timer.start('stdstarfit')
        num_err = num_cmd = 0
        if rank == 0:
            log.info('Starting stdstar fitting at {}'.format(time.asctime()))

        # -------------------------------------------------------------------------
        # - Get input fiberflat
        input_fiberflat = dict()
        if rank == 0:
            for camera in args.cameras:
                if args.fiberflat is not None:
                    input_fiberflat[camera] = args.fiberflat
                elif args.calibnight is not None:
                    # look for a fiberflatnight for this calib night
                    fiberflatnightfile = findfile('fiberflatnight',
                                                  args.calibnight,
                                                  args.expids[0], camera)
                    if not os.path.isfile(fiberflatnightfile):
                        log.error("no {}".format(fiberflatnightfile))
                        raise IOError("no {}".format(fiberflatnightfile))
                    input_fiberflat[camera] = fiberflatnightfile
                else:
                    # look for a fiberflatnight fiberflat
                    fiberflatnightfile = findfile('fiberflatnight', args.night,
                                                  args.expids[0], camera)
                if os.path.isfile(fiberflatnightfile):
                    input_fiberflat[camera] = fiberflatnightfile
                elif args.most_recent_calib:
                    # -- NOTE: Finding most recent only with respect to the first night
                    nightfile = find_most_recent(args.night,
                                                 file_type='fiberflatnight')
                    if nightfile is None:
                        input_fiberflat[camera] = findcalibfile(
                            [hdr, camhdr[camera]], 'FIBERFLAT')
                    else:
                        input_fiberflat[camera] = nightfile
                else:
                    input_fiberflat[camera] = findcalibfile(
                        [hdr, camhdr[camera]], 'FIBERFLAT')
            log.info("Will use input FIBERFLAT: {}".format(
                input_fiberflat[camera]))

        if comm is not None:
            input_fiberflat = comm.bcast(input_fiberflat, root=0)

        # - Group inputs by spectrograph
        framefiles = dict()
        skyfiles = dict()
        fiberflatfiles = dict()
        for camera in args.cameras:
            sp = int(camera[1])
            if sp not in framefiles:
                framefiles[sp] = list()
                skyfiles[sp] = list()
                fiberflatfiles[sp] = list()

            fiberflatfiles[sp].append(input_fiberflat[camera])
            for expid in args.expids:
                framefiles[sp].append(
                    findfile('frame', args.night, expid, camera))
                skyfiles[sp].append(findfile('sky', args.night, expid, camera))

        # - Hardcoded stdstar model version
        starmodels = os.path.join(os.getenv('DESI_BASIS_TEMPLATES'),
                                  'stdstar_templates_v2.2.fits')

        # - Fit stdstars per spectrograph (not per-camera)
        spectro_nums = sorted(framefiles.keys())
        ## for sp in spectro_nums[rank::size]:
        for i in range(rank, len(spectro_nums), size):
            sp = spectro_nums[i]
            # - NOTE: Saving the joint fit file with only the name of the first exposure
            stdfile = findfile('stdstars',
                               args.night,
                               args.expids[0],
                               spectrograph=sp)
            #stdfile.replace('{:08d}'.format(args.expids[0]),'-'.join(['{:08d}'.format(eid) for eid in args.expids]))
            cmd = "desi_fit_stdstars"
            cmd += " --delta-color 0.1"
            cmd += " --frames {}".format(' '.join(framefiles[sp]))
            cmd += " --skymodels {}".format(' '.join(skyfiles[sp]))
            cmd += " --fiberflats {}".format(' '.join(fiberflatfiles[sp]))
            cmd += " --starmodels {}".format(starmodels)
            cmd += " --outfile {}".format(stdfile)
            if args.maxstdstars is not None:
                cmd += " --maxstdstars {}".format(args.maxstdstars)

            inputs = framefiles[sp] + skyfiles[sp] + fiberflatfiles[sp]
            num_cmd += 1
            err = runcmd(cmd, inputs=inputs, outputs=[stdfile])
            if err:
                num_err += 1

        timer.stop('stdstarfit')
        num_cmd, num_err = mpi_count_failures(num_cmd, num_err, comm=comm)
        if comm is not None:
            comm.barrier()

        if rank == 0 and num_err > 0:
            log.error(f'{num_err}/{num_cmd} stdstar commands failed')

        sys.stdout.flush()
        if num_err > 0 and num_err == num_cmd:
            if rank == 0:
                log.critical('All stdstar commands failed')
            sys.exit(1)

        if rank == 0 and len(args.expids) > 1:
            for sp in spectro_nums:
                saved_stdfile = findfile('stdstars',
                                         args.night,
                                         args.expids[0],
                                         spectrograph=sp)
                for expid in args.expids[1:]:
                    new_stdfile = findfile('stdstars',
                                           args.night,
                                           expid,
                                           spectrograph=sp)
                    new_dirname, new_fname = os.path.split(new_stdfile)
                    log.debug(
                        "Path exists: {}, file exists: {}, link exists: {}".
                        format(os.path.exists(new_stdfile),
                               os.path.isfile(new_stdfile),
                               os.path.islink(new_stdfile)))
                    relpath_saved_std = os.path.relpath(
                        saved_stdfile, new_dirname)
                    log.debug(f'Sym Linking jointly fitted stdstar file: {new_stdfile} '+\
                            f'to existing file at rel. path {relpath_saved_std}')
                    runcmd(os.symlink, args=(relpath_saved_std, new_stdfile), \
                        inputs=[saved_stdfile, ], outputs=[new_stdfile, ])
                    log.debug(
                        "Path exists: {}, file exists: {}, link exists: {}".
                        format(os.path.exists(new_stdfile),
                               os.path.isfile(new_stdfile),
                               os.path.islink(new_stdfile)))

    # -------------------------------------------------------------------------
    # - Wrap up

    # if rank == 0:
    #     report = timer.report()
    #     log.info('Rank 0 timing report:\n' + report)

    if comm is not None:
        timers = comm.gather(timer, root=0)
    else:
        timers = [
            timer,
        ]

    if rank == 0:
        stats = desiutil.timer.compute_stats(timers)
        log.info('Timing summary statistics:\n' + json.dumps(stats, indent=2))

        if args.timingfile:
            if os.path.exists(args.timingfile):
                with open(args.timingfile) as fx:
                    previous_stats = json.load(fx)

                #- augment previous_stats with new entries, but don't overwrite old
                for name in stats:
                    if name not in previous_stats:
                        previous_stats[name] = stats[name]

                stats = previous_stats

            tmpfile = args.timingfile + '.tmp'
            with open(tmpfile, 'w') as fx:
                json.dump(stats, fx, indent=2)
            os.rename(tmpfile, args.timingfile)

    if rank == 0:
        log.info('All done at {}'.format(time.asctime()))
Пример #10
0
def main(args, comm=None):
    """Fit a specex PSF for a range of fibers, optionally spread over MPI ranks.

    The fibers ``[args.specmin, args.specmin + args.nspec)`` are split into
    bundles of ``args.bundlesize`` fibers.  The bundles are distributed over
    the ranks of ``comm``.  Each rank calls ``desi_psf_fit`` (through the
    ``libspecex`` C library) once per bundle, passing
    ``<outroot>_<bundle>.fits`` as the output PSF.  Rank 0 then merges those
    per-bundle files into ``<outroot>.fits`` and removes them.

    Args:
        args: parsed command-line namespace.  Reads ``input_image``,
            ``output_psf``, ``input_psf`` (may be None: the PSF is then looked
            up from the image header), ``extra`` (extra options string, may
            be None), ``specmin``, ``nspec``, ``bundlesize`` and ``debug``.
        comm: optional MPI communicator; when None everything runs serially.

    Raises:
        RuntimeError: if ``args.output_psf`` does not match ``.*\\.fits``,
            if any bundle fit returns non-zero, or if merging the per-bundle
            files fails.  Failure counts are shared across ranks so that
            every rank raises instead of some ranks hanging.
    """
    import traceback

    log = get_logger()

    imgfile = args.input_image
    outfile = args.output_psf

    if args.input_psf is not None:
        inpsffile = args.input_psf
    else:
        # No PSF given: look up the calibration PSF from the image header
        from desispec.calibfinder import findcalibfile
        hdr = fits.getheader(imgfile)
        inpsffile = findcalibfile([hdr,], 'PSF')

    optarray = []
    if args.extra is not None:
        optarray = args.extra.split()

    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)

    specmax = specmin + nspec

    # Divide the fiber range into bundles.  ``bundles`` holds the actual
    # bundle numbers (fiber // bundlesize); they do not start at 0 when
    # specmin >= bundlesize.  bspecmin / bnspec give, per bundle number, the
    # first fiber and the number of fibers that fall inside [specmin, specmax).
    bundles = sorted(set((np.arange(specmin, specmax) // bundlesize).tolist()))
    nbundle = len(bundles)

    bspecmin = {}
    bnspec = {}
    for b in bundles:
        bspecmin[b] = max(specmin, b * bundlesize)
        bnspec[b] = min((b + 1) * bundlesize, specmax) - bspecmin[b]

    # Now we assign bundles to processes

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    # Contiguous block distribution: the first ``leftover`` ranks get one
    # extra bundle each.
    mynbundle = nbundle // nproc
    leftover = nbundle % nproc
    if rank < leftover:
        mynbundle += 1
        myfirstbundle = rank * mynbundle
    else:
        myfirstbundle = ((mynbundle + 1) * leftover) + \
            (mynbundle * (rank - leftover))

    # myfirstbundle/mynbundle are positions in ``bundles``; map them to the
    # real bundle numbers so fiber ranges and file names are correct when
    # specmin > 0 (and consistent with the merge step below).
    mybundles = bundles[myfirstbundle:myfirstbundle + mynbundle]

    if rank == 0:
        # Print parameters
        log.info("specex:  using {} processes".format(nproc))
        log.info("specex:  input image = {}".format(imgfile))
        log.info("specex:  input PSF = {}".format(inpsffile))
        log.info("specex:  output = {}".format(outfile))
        log.info("specex:  bundlesize = {}".format(bundlesize))
        log.info("specex:  specmin = {}".format(specmin))
        log.info("specex:  specmax = {}".format(specmax))

    # get the root output file

    outpat = re.compile(r'(.*)\.fits')
    outmat = outpat.match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)

    outdir = os.path.dirname(outroot)
    # outdir is '' for a bare filename (current directory); makedirs('')
    # would raise, so only create a directory when there is one.
    if rank == 0 and outdir:
        os.makedirs(outdir, exist_ok=True)

    # All ranks write per-bundle files into outdir: wait until it exists.
    if comm is not None:
        comm.barrier()

    failcount = 0

    for b in mybundles:
        outbundle = "{}_{:02d}".format(outroot, b)
        outbundlefits = "{}.fits".format(outbundle)
        com = ['desi_psf_fit']
        com.extend(['-a', imgfile])
        com.extend(['--in-psf', inpsffile])
        com.extend(['--out-psf', outbundlefits])
        com.extend(['--first-bundle', "{}".format(b)])
        com.extend(['--last-bundle', "{}".format(b)])
        com.extend(['--first-fiber', "{}".format(bspecmin[b])])
        com.extend(['--last-fiber', "{}".format(bspecmin[b]+bnspec[b]-1)])
        if args.debug :
            com.extend(['--debug'])

        com.extend(optarray)

        log.debug("proc {} calling {}".format(rank, " ".join(com)))

        # Build a C-style argv (array of char*) for the library entry point
        argc = len(com)
        arg_buffers = [ct.create_string_buffer(com[i].encode('ascii')) \
            for i in range(argc)]
        addrlist = [ ct.cast(x, ct.POINTER(ct.c_char)) for x in \
            map(ct.addressof, arg_buffers) ]
        arg_pointers = (ct.POINTER(ct.c_char) * argc)(*addrlist)

        retval = libspecex.cspecex_desi_psf_fit(argc, arg_pointers)

        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                "value {} running {}".format(rank, retval, comstr))
            failcount += 1

    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")

    if rank == 0:
        outfits = "{}.fits".format(outroot)

        inputs = [ "{}_{:02d}.fits".format(outroot, x) for x in bundles ]

        try:
            merge_psf(inputs,outfits)
        except Exception:
            # Do not raise here: the other ranks are waiting in the bcast
            # below and would hang.  Record the failure and let every rank
            # raise together.
            log.error("merging per-bundle PSFs into {} failed:\n{}".format(
                outfits, traceback.format_exc()))
            failcount += 1
        else:
            # only remove the per-bundle files if the merge was good
            for f in inputs :
                if os.path.isfile(f):
                    os.remove(f)

    if comm is not None:
        failcount = comm.bcast(failcount, root=0)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("merging of per-bundle files failed")

    return
Пример #11
0
def main(args=None, comm=None):
    if args is None:
        args = parse()
    # elif isinstance(args, (list, tuple)):
    #     args = parse(args)

    log = get_logger()

    start_mpi_connect = time.time()
    if comm is not None:
        #- Use the provided comm to determine rank and size
        rank = comm.rank
        size = comm.size
    else:
        #- Check MPI flags and determine the comm, rank, and size given the arguments
        comm, rank, size = assign_mpi(do_mpi=args.mpi,
                                      do_batch=args.batch,
                                      log=log)
    stop_mpi_connect = time.time()

    #- Start timer; only print log messages from rank 0 (others are silent)
    timer = desiutil.timer.Timer(silent=(rank > 0))

    #- Fill in timing information for steps before we had the timer created
    if args.starttime is not None:
        timer.start('startup', starttime=args.starttime)
        timer.stop('startup', stoptime=start_imports)

    timer.start('imports', starttime=start_imports)
    timer.stop('imports', stoptime=stop_imports)

    timer.start('mpi_connect', starttime=start_mpi_connect)
    timer.stop('mpi_connect', stoptime=stop_mpi_connect)

    #- Freeze IERS after parsing args so that it doesn't bother if only --help
    timer.start('freeze_iers')
    desiutil.iers.freeze_iers()
    timer.stop('freeze_iers')

    #- Preflight checks
    timer.start('preflight')
    if rank > 0:
        #- Let rank 0 fetch these, and then broadcast
        args, hdr, camhdr = None, None, None
    else:
        args, hdr, camhdr = update_args_with_headers(args)

    ## Make sure badamps is formatted properly
    if comm is not None and rank == 0 and args.badamps is not None:
        args.badamps = validate_badamps(args.badamps)

    if comm is not None:
        args = comm.bcast(args, root=0)
        hdr = comm.bcast(hdr, root=0)
        camhdr = comm.bcast(camhdr, root=0)

    known_obstype = [
        'SCIENCE', 'ARC', 'FLAT', 'ZERO', 'DARK', 'TESTARC', 'TESTFLAT',
        'PIXFLAT', 'SKY', 'TWILIGHT', 'OTHER'
    ]
    if args.obstype not in known_obstype:
        raise RuntimeError('obstype {} not in {}'.format(
            args.obstype, known_obstype))

    timer.stop('preflight')

    #-------------------------------------------------------------------------
    #- Create and submit a batch job if requested

    if args.batch:
        #exp_str = '{:08d}'.format(args.expid)
        jobdesc = args.obstype.lower()
        if args.obstype == 'SCIENCE':
            # if not doing pre-stdstar fitting or stdstar fitting and if there is
            # no flag stopping flux calibration, set job to poststdstar
            if args.noprestdstarfit and args.nostdstarfit and (
                    not args.nofluxcalib):
                jobdesc = 'poststdstar'
            # elif told not to do std or post stdstar but the flag for prestdstar isn't set,
            # then perform prestdstar
            elif (not args.noprestdstarfit
                  ) and args.nostdstarfit and args.nofluxcalib:
                jobdesc = 'prestdstar'
            #elif (not args.noprestdstarfit) and (not args.nostdstarfit) and (not args.nofluxcalib):
            #    jobdesc = 'science'
        scriptfile = create_desi_proc_batch_script(night=args.night, exp=args.expid, cameras=args.cameras,\
                                                jobdesc=jobdesc, queue=args.queue, runtime=args.runtime,\
                                                batch_opts=args.batch_opts, timingfile=args.timingfile,
                                                system_name=args.system_name)
        err = 0
        if not args.nosubmit:
            err = subprocess.call(['sbatch', scriptfile])
        sys.exit(err)

    #-------------------------------------------------------------------------
    #- Proceeding with running

    #- What are we going to do?
    if rank == 0:
        log.info('----------')
        log.info('Input {}'.format(args.input))
        log.info('Night {} expid {}'.format(args.night, args.expid))
        log.info('Obstype {}'.format(args.obstype))
        log.info('Cameras {}'.format(args.cameras))
        log.info('Output root {}'.format(desispec.io.specprod_root()))
        log.info('----------')

    #- Create output directories if needed
    if rank == 0:
        preprocdir = os.path.dirname(
            findfile('preproc', args.night, args.expid, 'b0'))
        expdir = os.path.dirname(
            findfile('frame', args.night, args.expid, 'b0'))
        os.makedirs(preprocdir, exist_ok=True)
        os.makedirs(expdir, exist_ok=True)

    #- Wait for rank 0 to make directories before proceeding
    if comm is not None:
        comm.barrier()

    #-------------------------------------------------------------------------
    #- Preproc
    #- All obstypes get preprocessed

    timer.start('fibermap')

    #- Assemble fibermap for science exposures
    fibermap = None
    fibermap_ok = None
    if rank == 0 and args.obstype == 'SCIENCE':
        fibermap = findfile('fibermap', args.night, args.expid)
        if not os.path.exists(fibermap):
            tmp = findfile('preproc', args.night, args.expid, 'b0')
            preprocdir = os.path.dirname(tmp)
            fibermap = os.path.join(preprocdir, os.path.basename(fibermap))

            log.info('Creating fibermap {}'.format(fibermap))
            cmd = 'assemble_fibermap -n {} -e {} -o {}'.format(
                args.night, args.expid, fibermap)
            if args.badamps is not None:
                cmd += ' --badamps={}'.format(args.badamps)
            runcmd(cmd, inputs=[], outputs=[fibermap])

        fibermap_ok = os.path.exists(fibermap)

        #- Some commissioning files didn't have coords* files that caused assemble_fibermap to fail
        #- these are well known failures with no other solution, so for those, just force creation
        #- of a fibermap with null coordinate information
        if not fibermap_ok and int(args.night) < 20200310:
            log.info(
                "Since night is before 20200310, trying to force fibermap creation without coords file"
            )
            cmd += ' --force'
            runcmd(cmd, inputs=[], outputs=[fibermap])
            fibermap_ok = os.path.exists(fibermap)

    #- If assemble_fibermap failed and obstype is SCIENCE, exit now
    if comm is not None:
        fibermap_ok = comm.bcast(fibermap_ok, root=0)

    if args.obstype == 'SCIENCE' and not fibermap_ok:
        sys.stdout.flush()
        if rank == 0:
            log.critical(
                'assemble_fibermap failed for science exposure; exiting now')

        sys.exit(13)

    #- Wait for rank 0 to make fibermap if needed
    if comm is not None:
        fibermap = comm.bcast(fibermap, root=0)

    timer.stop('fibermap')

    if not (args.obstype in ['SCIENCE'] and args.noprestdstarfit):
        timer.start('preproc')
        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            outfile = findfile('preproc', args.night, args.expid, camera)
            outdir = os.path.dirname(outfile)
            cmd = "desi_preproc -i {} -o {} --outdir {} --cameras {}".format(
                args.input, outfile, outdir, camera)
            if args.scattered_light:
                cmd += " --scattered-light"
            if fibermap is not None:
                cmd += " --fibermap {}".format(fibermap)
            if not args.obstype in ['ARC']:  # never model variance for arcs
                if not args.no_model_pixel_variance:
                    cmd += " --model-variance"
            runcmd(cmd, inputs=[args.input], outputs=[outfile])

        timer.stop('preproc')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Get input PSFs
    timer.start('findpsf')
    input_psf = dict()
    if rank == 0:
        for camera in args.cameras:
            if args.psf is not None:
                input_psf[camera] = args.psf
            elif args.calibnight is not None:
                # look for a psfnight psf for this calib night
                psfnightfile = findfile('psfnight', args.calibnight,
                                        args.expid, camera)
                if not os.path.isfile(psfnightfile):
                    log.error("no {}".format(psfnightfile))
                    raise IOError("no {}".format(psfnightfile))
                input_psf[camera] = psfnightfile
            else:
                # look for a psfnight psf
                psfnightfile = findfile('psfnight', args.night, args.expid,
                                        camera)
                if os.path.isfile(psfnightfile):
                    input_psf[camera] = psfnightfile
                elif args.most_recent_calib:
                    nightfile = find_most_recent(args.night,
                                                 file_type='psfnight')
                    if nightfile is None:
                        input_psf[camera] = findcalibfile(
                            [hdr, camhdr[camera]], 'PSF')
                    else:
                        input_psf[camera] = nightfile
                else:
                    input_psf[camera] = findcalibfile([hdr, camhdr[camera]],
                                                      'PSF')
            log.info("Will use input PSF : {}".format(input_psf[camera]))

    if comm is not None:
        input_psf = comm.bcast(input_psf, root=0)

    timer.stop('findpsf')

    #-------------------------------------------------------------------------
    #- Traceshift

    if ( args.obstype in ['FLAT', 'TESTFLAT', 'SKY', 'TWILIGHT']     )   or \
    ( args.obstype in ['SCIENCE'] and (not args.noprestdstarfit) ):

        timer.start('traceshift')

        if rank == 0 and args.traceshift:
            log.info('Starting traceshift at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            preprocfile = findfile('preproc', args.night, args.expid, camera)
            inpsf = input_psf[camera]
            outpsf = findfile('psf', args.night, args.expid, camera)
            if not os.path.isfile(outpsf):
                if args.traceshift:
                    cmd = "desi_compute_trace_shifts"
                    cmd += " -i {}".format(preprocfile)
                    cmd += " --psf {}".format(inpsf)
                    cmd += " --outpsf {}".format(outpsf)
                    cmd += " --degxx 2 --degxy 0"
                    if args.obstype in ['FLAT', 'TESTFLAT', 'TWILIGHT']:
                        cmd += " --continuum"
                    else:
                        cmd += " --degyx 2 --degyy 0"
                    if args.obstype in ['SCIENCE', 'SKY']:
                        cmd += ' --sky'
                else:
                    cmd = "ln -s {} {}".format(inpsf, outpsf)
                runcmd(cmd, inputs=[preprocfile, inpsf], outputs=[outpsf])
            else:
                log.info("PSF {} exists".format(outpsf))

        timer.stop('traceshift')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- PSF
    #- MPI parallelize this step

    if args.obstype in ['ARC', 'TESTARC']:

        timer.start('arc_traceshift')

        if rank == 0:
            log.info('Starting traceshift before specex PSF fit at {}'.format(
                time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            preprocfile = findfile('preproc', args.night, args.expid, camera)
            inpsf = input_psf[camera]
            outpsf = findfile('psf', args.night, args.expid, camera)
            outpsf = replace_prefix(outpsf, "psf", "shifted-input-psf")
            if not os.path.isfile(outpsf):
                cmd = "desi_compute_trace_shifts"
                cmd += " -i {}".format(preprocfile)
                cmd += " --psf {}".format(inpsf)
                cmd += " --outpsf {}".format(outpsf)
                cmd += " --degxx 0 --degxy 0 --degyx 0 --degyy 0"
                cmd += ' --arc-lamps'
                runcmd(cmd, inputs=[preprocfile, inpsf], outputs=[outpsf])
            else:
                log.info("PSF {} exists".format(outpsf))

        timer.stop('arc_traceshift')
        if comm is not None:
            comm.barrier()

        timer.start('psf')

        if rank == 0:
            log.info('Starting specex PSF fitting at {}'.format(
                time.asctime()))

        if rank > 0:
            cmds = inputs = outputs = None
        else:
            cmds = dict()
            inputs = dict()
            outputs = dict()
            for camera in args.cameras:
                preprocfile = findfile('preproc', args.night, args.expid,
                                       camera)
                tmpname = findfile('psf', args.night, args.expid, camera)
                inpsf = replace_prefix(tmpname, "psf", "shifted-input-psf")
                outpsf = replace_prefix(tmpname, "psf", "fit-psf")

                log.info("now run specex psf fit")

                cmd = 'desi_compute_psf'
                cmd += ' --input-image {}'.format(preprocfile)
                cmd += ' --input-psf {}'.format(inpsf)
                cmd += ' --output-psf {}'.format(outpsf)

                # look for fiber blacklist
                cfinder = CalibFinder([hdr, camhdr[camera]])
                blacklistkey = "FIBERBLACKLIST"
                if not cfinder.haskey(blacklistkey) and cfinder.haskey(
                        "BROKENFIBERS"):
                    log.warning(
                        "BROKENFIBERS yaml keyword deprecated, please use FIBERBLACKLIST"
                    )
                    blacklistkey = "BROKENFIBERS"

                if cfinder.haskey(blacklistkey):
                    blacklist = cfinder.value(blacklistkey)
                    cmd += ' --broken-fibers {}'.format(blacklist)
                    if rank == 0:
                        log.warning('broken fibers: {}'.format(blacklist))

                if not os.path.exists(outpsf):
                    cmds[camera] = cmd
                    inputs[camera] = [preprocfile, inpsf]
                    outputs[camera] = [
                        outpsf,
                    ]

        if comm is not None:
            cmds = comm.bcast(cmds, root=0)
            inputs = comm.bcast(inputs, root=0)
            outputs = comm.bcast(outputs, root=0)
            #- split communicator by 20 (number of bundles)
            group_size = 20
            if (rank == 0) and (size % group_size != 0):
                log.warning(
                    'MPI size={} should be evenly divisible by {}'.format(
                        size, group_size))

            group = rank // group_size
            num_groups = (size + group_size - 1) // group_size
            comm_group = comm.Split(color=group)

            if rank == 0:
                log.info(
                    f'Fitting PSFs with {num_groups} sub-communicators of size {group_size}'
                )

            for i in range(group, len(args.cameras), num_groups):
                camera = args.cameras[i]
                if camera in cmds:
                    cmdargs = cmds[camera].split()[1:]
                    cmdargs = desispec.scripts.specex.parse(cmdargs)
                    if comm_group.rank == 0:
                        print('RUNNING: {}'.format(cmds[camera]))
                        t0 = time.time()
                        timestamp = time.asctime()
                        log.info(
                            f'MPI group {group} ranks {rank}-{rank+group_size-1} fitting PSF for {camera} at {timestamp}'
                        )
                    try:
                        desispec.scripts.specex.main(cmdargs, comm=comm_group)
                    except Exception as e:
                        if comm_group.rank == 0:
                            log.error(
                                f'FAILED: MPI group {group} ranks {rank}-{rank+group_size-1} camera {camera}'
                            )
                            log.error('FAILED: {}'.format(cmds[camera]))
                            log.error(e)

                    if comm_group.rank == 0:
                        specex_time = time.time() - t0
                        log.info(
                            f'specex fit for {camera} took {specex_time:.1f} seconds'
                        )

            comm.barrier()

        else:
            log.warning(
                'fitting PSFs without MPI parallelism; this will be SLOW')
            for camera in args.cameras:
                if camera in cmds:
                    runcmd(cmds[camera],
                           inputs=inputs[camera],
                           outputs=outputs[camera])

        if comm is not None:
            comm.barrier()

        # loop on all cameras and interpolate bad fibers
        for camera in args.cameras[rank::size]:
            t0 = time.time()
            log.info(f'Rank {rank} interpolating {camera} PSF over bad fibers')
            # look for fiber blacklist
            cfinder = CalibFinder([hdr, camhdr[camera]])
            blacklistkey = "FIBERBLACKLIST"
            if not cfinder.haskey(blacklistkey) and cfinder.haskey(
                    "BROKENFIBERS"):
                log.warning(
                    "BROKENFIBERS yaml keyword deprecated, please use FIBERBLACKLIST"
                )
                blacklistkey = "BROKENFIBERS"

            if cfinder.haskey(blacklistkey):
                fiberblacklist = cfinder.value(blacklistkey)
                tmpname = findfile('psf', args.night, args.expid, camera)
                inpsf = replace_prefix(tmpname, "psf", "fit-psf")
                outpsf = replace_prefix(tmpname, "psf",
                                        "fit-psf-fixed-blacklisted")
                if os.path.isfile(inpsf) and not os.path.isfile(outpsf):
                    cmd = 'desi_interpolate_fiber_psf'
                    cmd += ' --infile {}'.format(inpsf)
                    cmd += ' --outfile {}'.format(outpsf)
                    cmd += ' --fibers {}'.format(fiberblacklist)
                    log.info(
                        'For camera {} interpolating PSF for broken fibers: {}'
                        .format(camera, fiberblacklist))
                    runcmd(cmd, inputs=[inpsf], outputs=[outpsf])
                    if os.path.isfile(outpsf):
                        os.rename(
                            inpsf,
                            inpsf.replace("fit-psf",
                                          "fit-psf-before-blacklisted-fix"))
                        subprocess.call('cp {} {}'.format(outpsf, inpsf),
                                        shell=True)

            dt = time.time() - t0
            log.info(
                f'Rank {rank} {camera} PSF interpolation took {dt:.1f} sec')

        timer.stop('psf')

    #-------------------------------------------------------------------------
    #- Merge PSF of night if applicable

    #if args.obstype in ['ARC']:
    if False:
        if rank == 0:
            for camera in args.cameras:
                psfnightfile = findfile('psfnight', args.night, args.expid,
                                        camera)
                if not os.path.isfile(
                        psfnightfile
                ):  # we still don't have a psf night, see if we can compute it ...
                    psfs = glob.glob(
                        findfile('psf', args.night, args.expid,
                                 camera).replace("psf", "fit-psf").replace(
                                     str(args.expid), "*"))
                    log.info(
                        "Number of PSF for night={} camera={} = {}".format(
                            args.night, camera, len(psfs)))
                    if len(psfs) > 4:  # lets do it!
                        log.info("Computing psfnight ...")
                        dirname = os.path.dirname(psfnightfile)
                        if not os.path.isdir(dirname):
                            os.makedirs(dirname)
                        desispec.scripts.specex.mean_psf(psfs, psfnightfile)
                if os.path.isfile(psfnightfile):  # now use this one
                    input_psf[camera] = psfnightfile

    #-------------------------------------------------------------------------
    #- Extract
    #- This is MPI parallel so handle a bit differently

    # maybe add ARC and TESTARC too
    if ( args.obstype in ['FLAT', 'TESTFLAT', 'SKY', 'TWILIGHT']     )   or \
    ( args.obstype in ['SCIENCE'] and (not args.noprestdstarfit) ):

        timer.start('extract')
        if rank == 0:
            log.info('Starting extractions at {}'.format(time.asctime()))

        if rank > 0:
            cmds = inputs = outputs = None
        else:
            cmds = dict()
            inputs = dict()
            outputs = dict()
            for camera in args.cameras:
                cmd = 'desi_extract_spectra'

                #- Based on data from SM1-SM8, looking at central and edge fibers
                #- with in mind overlapping arc lamps lines
                if camera.startswith('b'):
                    cmd += ' -w 3600.0,5800.0,0.8'
                elif camera.startswith('r'):
                    cmd += ' -w 5760.0,7620.0,0.8'
                elif camera.startswith('z'):
                    cmd += ' -w 7520.0,9824.0,0.8'

                preprocfile = findfile('preproc', args.night, args.expid,
                                       camera)
                psffile = findfile('psf', args.night, args.expid, camera)
                framefile = findfile('frame', args.night, args.expid, camera)
                cmd += ' -i {}'.format(preprocfile)
                cmd += ' -p {}'.format(psffile)
                cmd += ' -o {}'.format(framefile)
                cmd += ' --psferr 0.1'

                if args.obstype == 'SCIENCE' or args.obstype == 'SKY':
                    if rank == 0:
                        log.info('Include barycentric correction')
                    cmd += ' --barycentric-correction'

                if not os.path.exists(framefile):
                    cmds[camera] = cmd
                    inputs[camera] = [preprocfile, psffile]
                    outputs[camera] = [
                        framefile,
                    ]

        #- TODO: refactor/combine this with PSF comm splitting logic
        if comm is not None:
            cmds = comm.bcast(cmds, root=0)
            inputs = comm.bcast(inputs, root=0)
            outputs = comm.bcast(outputs, root=0)

            #- split communicator by 20 (number of bundles)
            extract_size = 20
            if (rank == 0) and (size % extract_size != 0):
                log.warning(
                    'MPI size={} should be evenly divisible by {}'.format(
                        size, extract_size))

            extract_group = rank // extract_size
            num_extract_groups = (size + extract_size - 1) // extract_size
            comm_extract = comm.Split(color=extract_group)

            for i in range(extract_group, len(args.cameras),
                           num_extract_groups):
                camera = args.cameras[i]
                if camera in cmds:
                    cmdargs = cmds[camera].split()[1:]
                    extract_args = desispec.scripts.extract.parse(cmdargs)
                    if comm_extract.rank == 0:
                        print('RUNNING: {}'.format(cmds[camera]))

                    desispec.scripts.extract.main_mpi(extract_args,
                                                      comm=comm_extract)

            comm.barrier()

        else:
            log.warning(
                'running extractions without MPI parallelism; this will be SLOW'
            )
            for camera in args.cameras:
                if camera in cmds:
                    runcmd(cmds[camera],
                           inputs=inputs[camera],
                           outputs=outputs[camera])

        timer.stop('extract')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Fiberflat

    if args.obstype in ['FLAT', 'TESTFLAT']:
        timer.start('fiberflat')
        if rank == 0:
            log.info('Starting fiberflats at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            framefile = findfile('frame', args.night, args.expid, camera)
            fiberflatfile = findfile('fiberflat', args.night, args.expid,
                                     camera)
            cmd = "desi_compute_fiberflat"
            cmd += " -i {}".format(framefile)
            cmd += " -o {}".format(fiberflatfile)
            runcmd(cmd, inputs=[
                framefile,
            ], outputs=[
                fiberflatfile,
            ])

        timer.stop('fiberflat')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Average and auto-calib fiberflats of night if applicable

    #if args.obstype in ['FLAT']:
    if False:
        if rank == 0:
            fiberflatnightfile = findfile('fiberflatnight', args.night,
                                          args.expid, args.cameras[0])
            fiberflatdirname = os.path.dirname(fiberflatnightfile)
            if not os.path.isfile(fiberflatnightfile) and len(
                    args.cameras
            ) >= 6:  # we still don't have them, see if we can compute them, but need at least 2 spectros ...
                flats = glob.glob(
                    findfile('fiberflat', args.night, args.expid,
                             "b0").replace(str(args.expid),
                                           "*").replace("b0", "*"))
                log.info("Number of fiberflat for night {} = {}".format(
                    args.night, len(flats)))
                if len(flats) >= 3 * 4 * len(
                        args.cameras
                ):  # lets do it! (3 exposures x 4 lamps x N cameras)
                    log.info(
                        "Computing fiberflatnight per lamp and camera ...")
                    tmpdir = os.path.join(fiberflatdirname, "tmp")
                    if not os.path.isdir(tmpdir):
                        os.makedirs(tmpdir)

                    log.info(
                        "First average measurements per camera and per lamp")
                    average_flats = dict()
                    for camera in args.cameras:
                        # list of flats for this camera
                        flats_for_this_camera = []
                        for flat in flats:
                            if flat.find(camera) >= 0:
                                flats_for_this_camera.append(flat)
                        #log.info("For camera {} , flats = {}".format(camera,flats_for_this_camera))
                        #sys.exit(12)

                        # average per lamp (and camera)
                        average_flats[camera] = list()
                        for lampbox in range(4):
                            ofile = os.path.join(
                                tmpdir,
                                "fiberflatnight-camera-{}-lamp-{}.fits".format(
                                    camera, lampbox))
                            if not os.path.isfile(ofile):
                                log.info(
                                    "Average flat for camera {} and lamp box #{}"
                                    .format(camera, lampbox))
                                pg = "CALIB DESI-CALIB-0{} LEDs only".format(
                                    lampbox)

                                cmd = "desi_average_fiberflat --program '{}' --outfile {} -i ".format(
                                    pg, ofile)
                                for flat in flats_for_this_camera:
                                    cmd += " {} ".format(flat)
                                runcmd(cmd,
                                       inputs=flats_for_this_camera,
                                       outputs=[
                                           ofile,
                                       ])
                                if os.path.isfile(ofile):
                                    average_flats[camera].append(ofile)
                            else:
                                log.info("Will use existing {}".format(ofile))
                                average_flats[camera].append(ofile)

                    log.info(
                        "Auto-calibration across lamps and spectro  per camera arm (b,r,z)"
                    )
                    for camera_arm in ["b", "r", "z"]:
                        cameras_for_this_arm = []
                        flats_for_this_arm = []
                        for camera in args.cameras:
                            if camera[0].lower() == camera_arm:
                                cameras_for_this_arm.append(camera)
                                if camera in average_flats:
                                    for flat in average_flats[camera]:
                                        flats_for_this_arm.append(flat)
                        cmd = "desi_autocalib_fiberflat --night {} --arm {} -i ".format(
                            args.night, camera_arm)
                        for flat in flats_for_this_arm:
                            cmd += " {} ".format(flat)
                        runcmd(cmd, inputs=flats_for_this_arm, outputs=[])
                    log.info("Done with fiber flats per night")

        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Get input fiberflat
    if args.obstype in ['SCIENCE', 'SKY'] and (not args.nofiberflat):
        timer.start('find_fiberflat')
        input_fiberflat = dict()
        if rank == 0:
            for camera in args.cameras:
                if args.fiberflat is not None:
                    input_fiberflat[camera] = args.fiberflat
                elif args.calibnight is not None:
                    # look for a fiberflatnight for this calib night
                    fiberflatnightfile = findfile('fiberflatnight',
                                                  args.calibnight, args.expid,
                                                  camera)
                    if not os.path.isfile(fiberflatnightfile):
                        log.error("no {}".format(fiberflatnightfile))
                        raise IOError("no {}".format(fiberflatnightfile))
                    input_fiberflat[camera] = fiberflatnightfile
                else:
                    # look for a fiberflatnight fiberflat
                    fiberflatnightfile = findfile('fiberflatnight', args.night,
                                                  args.expid, camera)
                    if os.path.isfile(fiberflatnightfile):
                        input_fiberflat[camera] = fiberflatnightfile
                    elif args.most_recent_calib:
                        nightfile = find_most_recent(
                            args.night, file_type='fiberflatnight')
                        if nightfile is None:
                            input_fiberflat[camera] = findcalibfile(
                                [hdr, camhdr[camera]], 'FIBERFLAT')
                        else:
                            input_fiberflat[camera] = nightfile
                    else:
                        input_fiberflat[camera] = findcalibfile(
                            [hdr, camhdr[camera]], 'FIBERFLAT')
                log.info("Will use input FIBERFLAT: {}".format(
                    input_fiberflat[camera]))

        if comm is not None:
            input_fiberflat = comm.bcast(input_fiberflat, root=0)

        timer.stop('find_fiberflat')

    #-------------------------------------------------------------------------
    #- Apply fiberflat and write fframe file

    if args.obstype in ['SCIENCE', 'SKY'] and args.fframe and \
    ( not args.nofiberflat ) and (not args.noprestdstarfit):
        timer.start('apply_fiberflat')
        if rank == 0:
            log.info('Applying fiberflat at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            fframefile = findfile('fframe', args.night, args.expid, camera)
            if not os.path.exists(fframefile):
                framefile = findfile('frame', args.night, args.expid, camera)
                fr = desispec.io.read_frame(framefile)
                flatfilename = input_fiberflat[camera]
                if flatfilename is not None:
                    ff = desispec.io.read_fiberflat(flatfilename)
                    fr.meta['FIBERFLT'] = desispec.io.shorten_filename(
                        flatfilename)
                    apply_fiberflat(fr, ff)

                    fframefile = findfile('fframe', args.night, args.expid,
                                          camera)
                    desispec.io.write_frame(fframefile, fr)
                else:
                    log.warning(
                        "Missing fiberflat for camera {}".format(camera))

        timer.stop('apply_fiberflat')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Select random sky fibers (inplace update of frame file)
    #- TODO: move this to a function somewhere
    #- TODO: this assigns different sky fibers to each frame of same spectrograph

    if (args.obstype in [
            'SKY', 'SCIENCE'
    ]) and (not args.noskysub) and (not args.noprestdstarfit):
        timer.start('picksky')
        if rank == 0:
            log.info('Picking sky fibers at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            framefile = findfile('frame', args.night, args.expid, camera)
            orig_frame = desispec.io.read_frame(framefile)

            #- Make a copy so that we can apply fiberflat
            fr = deepcopy(orig_frame)

            if np.any(fr.fibermap['OBJTYPE'] == 'SKY'):
                log.info('{} sky fibers already set; skipping'.format(
                    os.path.basename(framefile)))
                continue

            #- Apply fiberflat then select random fibers below a flux cut
            flatfilename = input_fiberflat[camera]
            if flatfilename is None:
                log.error("No fiberflat for {}".format(camera))
                continue
            ff = desispec.io.read_fiberflat(flatfilename)
            apply_fiberflat(fr, ff)
            sumflux = np.sum(fr.flux, axis=1)
            fluxcut = np.percentile(sumflux, 30)
            iisky = np.where(sumflux < fluxcut)[0]
            iisky = np.random.choice(iisky, size=100, replace=False)

            #- Update fibermap or original frame and write out
            orig_frame.fibermap['OBJTYPE'][iisky] = 'SKY'
            orig_frame.fibermap['DESI_TARGET'][iisky] |= desi_mask.SKY

            desispec.io.write_frame(framefile, orig_frame)

        timer.stop('picksky')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Sky subtraction
    if args.obstype in [
            'SCIENCE', 'SKY'
    ] and (not args.noskysub) and (not args.noprestdstarfit):
        timer.start('skysub')
        if rank == 0:
            log.info('Starting sky subtraction at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            framefile = findfile('frame', args.night, args.expid, camera)
            hdr = fitsio.read_header(framefile, 'FLUX')
            fiberflatfile = input_fiberflat[camera]
            if fiberflatfile is None:
                log.error("No fiberflat for {}".format(camera))
                continue
            skyfile = findfile('sky', args.night, args.expid, camera)

            cmd = "desi_compute_sky"
            cmd += " -i {}".format(framefile)
            cmd += " --fiberflat {}".format(fiberflatfile)
            cmd += " --o {}".format(skyfile)
            if args.no_extra_variance:
                cmd += " --no-extra-variance"
            if not args.no_sky_wavelength_adjustment:
                cmd += " --adjust-wavelength"
            if not args.no_sky_lsf_adjustment: cmd += " --adjust-lsf"

            runcmd(cmd, inputs=[framefile, fiberflatfile], outputs=[
                skyfile,
            ])

            #- sframe = flatfielded sky-subtracted but not flux calibrated frame
            #- Note: this re-reads and re-does steps previously done for picking
            #- sky fibers; desi_proc is about human efficiency,
            #- not I/O or CPU efficiency...
            sframefile = desispec.io.findfile('sframe', args.night, args.expid,
                                              camera)
            if not os.path.exists(sframefile):
                frame = desispec.io.read_frame(framefile)
                fiberflat = desispec.io.read_fiberflat(fiberflatfile)
                sky = desispec.io.read_sky(skyfile)
                apply_fiberflat(frame, fiberflat)
                subtract_sky(frame, sky, apply_throughput_correction=True)
                frame.meta['IN_SKY'] = shorten_filename(skyfile)
                frame.meta['FIBERFLT'] = shorten_filename(fiberflatfile)
                desispec.io.write_frame(sframefile, frame)

        timer.stop('skysub')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Standard Star Fitting

    if args.obstype in ['SCIENCE',] and \
            (not args.noskysub ) and \
            (not args.nostdstarfit) :

        timer.start('stdstarfit')
        if rank == 0:
            log.info('Starting flux calibration at {}'.format(time.asctime()))

        #- Group inputs by spectrograph
        framefiles = dict()
        skyfiles = dict()
        fiberflatfiles = dict()
        night, expid = args.night, args.expid  #- shorter
        for camera in args.cameras:
            sp = int(camera[1])
            if sp not in framefiles:
                framefiles[sp] = list()
                skyfiles[sp] = list()
                fiberflatfiles[sp] = list()

            framefiles[sp].append(findfile('frame', night, expid, camera))
            skyfiles[sp].append(findfile('sky', night, expid, camera))
            fiberflatfiles[sp].append(input_fiberflat[camera])

        #- Hardcoded stdstar model version
        starmodels = os.path.join(os.getenv('DESI_BASIS_TEMPLATES'),
                                  'stdstar_templates_v2.2.fits')

        #- Fit stdstars per spectrograph (not per-camera)
        spectro_nums = sorted(framefiles.keys())
        ## for sp in spectro_nums[rank::size]:
        for i in range(rank, len(spectro_nums), size):
            sp = spectro_nums[i]

            stdfile = findfile('stdstars', night, expid, spectrograph=sp)
            cmd = "desi_fit_stdstars"
            cmd += " --frames {}".format(' '.join(framefiles[sp]))
            cmd += " --skymodels {}".format(' '.join(skyfiles[sp]))
            cmd += " --fiberflats {}".format(' '.join(fiberflatfiles[sp]))
            cmd += " --starmodels {}".format(starmodels)
            cmd += " --outfile {}".format(stdfile)
            cmd += " --delta-color 0.1"
            if args.maxstdstars is not None:
                cmd += " --maxstdstars {}".format(args.maxstdstars)

            inputs = framefiles[sp] + skyfiles[sp] + fiberflatfiles[sp]
            runcmd(cmd, inputs=inputs, outputs=[stdfile])

        timer.stop('stdstarfit')
        if comm is not None:
            comm.barrier()

    # -------------------------------------------------------------------------
    # - Flux calibration

    if args.obstype in ['SCIENCE'] and \
                (not args.noskysub) and \
                (not args.nofluxcalib):
        timer.start('fluxcalib')

        night, expid = args.night, args.expid  #- shorter
        #- Compute flux calibration vectors per camera
        for camera in args.cameras[rank::size]:
            framefile = findfile('frame', night, expid, camera)
            skyfile = findfile('sky', night, expid, camera)
            spectrograph = int(camera[1])
            stdfile = findfile('stdstars',
                               night,
                               expid,
                               spectrograph=spectrograph)
            calibfile = findfile('fluxcalib', night, expid, camera)

            fiberflatfile = input_fiberflat[camera]

            cmd = "desi_compute_fluxcalibration"
            cmd += " --infile {}".format(framefile)
            cmd += " --sky {}".format(skyfile)
            cmd += " --fiberflat {}".format(fiberflatfile)
            cmd += " --models {}".format(stdfile)
            cmd += " --outfile {}".format(calibfile)
            cmd += " --delta-color-cut 0.1"

            inputs = [framefile, skyfile, fiberflatfile, stdfile]
            runcmd(cmd, inputs=inputs, outputs=[
                calibfile,
            ])

        timer.stop('fluxcalib')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Applying flux calibration

    if args.obstype in [
            'SCIENCE',
    ] and (not args.noskysub) and (not args.nofluxcalib):

        night, expid = args.night, args.expid  #- shorter

        timer.start('applycalib')
        if rank == 0:
            log.info('Starting cframe file creation at {}'.format(
                time.asctime()))

        for camera in args.cameras[rank::size]:
            framefile = findfile('frame', night, expid, camera)
            skyfile = findfile('sky', night, expid, camera)
            spectrograph = int(camera[1])
            stdfile = findfile('stdstars',
                               night,
                               expid,
                               spectrograph=spectrograph)
            calibfile = findfile('fluxcalib', night, expid, camera)
            cframefile = findfile('cframe', night, expid, camera)

            fiberflatfile = input_fiberflat[camera]

            cmd = "desi_process_exposure"
            cmd += " --infile {}".format(framefile)
            cmd += " --fiberflat {}".format(fiberflatfile)
            cmd += " --sky {}".format(skyfile)
            cmd += " --calib {}".format(calibfile)
            cmd += " --outfile {}".format(cframefile)
            cmd += " --cosmics-nsig 6"
            if args.no_xtalk:
                cmd += " --no-xtalk"

            inputs = [framefile, fiberflatfile, skyfile, calibfile]
            runcmd(cmd, inputs=inputs, outputs=[
                cframefile,
            ])

        if comm is not None:
            comm.barrier()

        timer.stop('applycalib')

    #-------------------------------------------------------------------------
    #- Wrap up

    # if rank == 0:
    #     report = timer.report()
    #     log.info('Rank 0 timing report:\n' + report)

    if comm is not None:
        timers = comm.gather(timer, root=0)
    else:
        timers = [
            timer,
        ]

    if rank == 0:
        stats = desiutil.timer.compute_stats(timers)
        log.info('Timing summary statistics:\n' + json.dumps(stats, indent=2))

        if args.timingfile:
            if os.path.exists(args.timingfile):
                with open(args.timingfile) as fx:
                    previous_stats = json.load(fx)

                #- augment previous_stats with new entries, but don't overwrite old
                for name in stats:
                    if name not in previous_stats:
                        previous_stats[name] = stats[name]

                stats = previous_stats

            tmpfile = args.timingfile + '.tmp'
            with open(tmpfile, 'w') as fx:
                json.dump(stats, fx, indent=2)
            os.rename(tmpfile, args.timingfile)

    if rank == 0:
        log.info('All done at {}'.format(time.asctime()))
Пример #12
0
def main(args):
    """Measure fiber-to-fiber crosstalk from sky fibers and write it to a table.

    For every input frame, the flux of each usable sky fiber is modeled as a
    linear combination of the fluxes of its neighbors at fiber offsets
    ``dfiber`` (same bundle only), each convolved with the kernel returned by
    ``compute_crosstalk_kernels`` for that offset, plus an optional constant
    term.  Weighted normal equations are accumulated per bundle and per output
    wavelength bin over all frames and solved at the end.

    Args:
        args: parsed command-line arguments.  Attributes used here:
            ``infile`` (iterable of frame file paths), ``outfile`` (path of
            the output table, overwritten if it exists) and ``plot`` (if true,
            display the measured crosstalk with matplotlib).

    Raises:
        IOError: if the fiberflat file referenced by a frame does not exist.
        ValueError: if ``args.infile`` is empty.

    The output table has a ``WAVELENGTH`` column and, for every bundle ``BB``
    and fiber offset ``F``, the columns ``CROSSTALK-BBB-F±d``,
    ``CROSSTALKIVAR-BBB-F±d`` and ``NINPUT-BBB-F±d`` (number of sky fibers
    that contributed to the fit for that offset).
    """

    log = get_logger()

    # precompute convolution kernels (indexed by |fiber offset|)
    kernels = compute_crosstalk_kernels()

    A = None
    B = None
    out_wave = None

    # offsets of the neighboring fibers whose light is fitted in a sky fiber
    # (alternative: nearest neighbors only, dfiber = np.array([-1, 1]))
    dfiber = np.array([-2, -1, 1, 2])

    npar = dfiber.size
    with_cst = True  # to marginalize over residual background (should not change much)
    if with_cst:
        npar += 1

    # one measurement per fiber bundle
    nfiber_per_bundle = 25
    nbundles = 500 // nfiber_per_bundle

    previous_psf_filename = None
    previous_fiberflat_filename = None

    for filename in args.infile:

        # read a frame and find the sky fibers
        frame = read_frame(filename)

        if out_wave is None:
            # 40 output wavelength bins spanning the first frame's wavelength range
            dwave = (frame.wave[-1] - frame.wave[0]) / 40
            out_wave = np.linspace(frame.wave[0] + dwave / 2,
                                   frame.wave[-1] - dwave / 2, 40)

        # find fiberflat
        if "FIBERFLT" in frame.meta.keys():
            flatname = frame.meta["FIBERFLT"]
            if flatname.find("SPCALIB") >= 0:
                flatname = flatname.replace(
                    "SPCALIB", os.environ["DESI_SPECTRO_CALIB"] + "/")
            if flatname.find("SPECPROD") >= 0:
                # this one is harder :-(
                # the production root is taken to be four directory levels
                # above the frame file
                dirname = os.path.dirname(
                    os.path.dirname(os.path.dirname(
                        os.path.dirname(filename))))
                flatname = flatname.replace("SPECPROD", dirname + "/")

        else:
            flatname = findcalibfile([
                frame.meta,
            ], "FIBERFLAT")
        if flatname is not None:
            if previous_fiberflat_filename is not None and previous_fiberflat_filename == flatname:
                log.info("Using same fiberflat")
            else:
                if not os.path.isfile(flatname):
                    log.error("Cannot open fiberflat file {}".format(flatname))
                    raise IOError(
                        "Cannot open fiberflat file {}".format(flatname))
                log.info("Using fiberflat {}".format(flatname))
                fiberflat = read_fiberflat(flatname)
                medflat = np.median(fiberflat.fiberflat, axis=1)
                previous_fiberflat_filename = flatname
        else:
            medflat = None
            # invalidate the cache: medflat no longer matches the last flat
            # read, so a later frame using that flat must reload it
            previous_fiberflat_filename = None
            log.warning("No fiberflat")

        skyfibers = np.where((frame.fibermap["OBJTYPE"] == "SKY")
                             & (frame.fibermap["FIBERSTATUS"] == 0))[0]
        log.info("{} sky fibers in {}".format(skyfibers.size, filename))

        frame.ivar *= (
            (frame.mask == 0) | (frame.mask == specmask.BADFIBER)
        )  # ignore BADFIBER which is a statement on the positioning

        # also open trace set to determine the shift
        # to apply to adjacent spectra
        psf_filename = findcalibfile([
            frame.meta,
        ], "PSF")

        # only reread if necessary
        if previous_psf_filename is None or previous_psf_filename != psf_filename:
            tset = read_xytraceset(psf_filename)
            previous_psf_filename = psf_filename

        # will use this y
        central_y = tset.npix_y // 2

        if A is None:
            A = np.zeros((nbundles, npar, npar, out_wave.size))
            B = np.zeros((nbundles, npar, out_wave.size))
            fA = np.zeros((npar, npar, out_wave.size))
            fB = np.zeros((npar, out_wave.size))
            ninput = np.zeros((nbundles, dfiber.size))

        for skyfiber in skyfibers:
            cflux = np.zeros((dfiber.size, out_wave.size))
            skyfiberbundle = skyfiber // nfiber_per_bundle

            nbad = np.sum(frame.ivar[skyfiber] == 0)
            if nbad > 200:
                if nbad < 2000:
                    log.warning(
                        "ignore skyfiber {} from {} with {} masked pixel".
                        format(skyfiber, filename, nbad))
                continue

            skyfiber_central_wave = tset.wave_vs_y(skyfiber, central_y)

            should_consider = False
            must_exclude = False
            # reset the per-fiber contributions to the normal equations
            fA *= 0.
            fB *= 0.

            use_median_filter = False  # not needed
            median_filter_width = 30
            skyfiberflux, skyfiberivar = resample_flux(out_wave, frame.wave,
                                                       frame.flux[skyfiber],
                                                       frame.ivar[skyfiber])
            if medflat is not None:
                skyfiberflux *= medflat[
                    skyfiber]  # apply relative transmission of fiber, i.e. undo the fiberflat correction

            if use_median_filter:
                good = (skyfiberivar > 0)
                skyfiberflux = np.interp(out_wave, out_wave[good],
                                         skyfiberflux[good])
                skyfiberflux = scipy.ndimage.median_filter(
                    skyfiberflux, median_filter_width, mode='constant')

            for i, df in enumerate(dfiber):
                otherfiber = df + skyfiber
                if otherfiber < 0: continue
                if otherfiber >= frame.nspec: continue
                if otherfiber // nfiber_per_bundle != skyfiberbundle:
                    continue  # not same bundle

                snr = np.sqrt(frame.ivar[otherfiber]) * frame.flux[otherfiber]
                medsnr = np.median(snr)
                if medsnr > 2:  # need good SNR to model cross talk
                    should_consider = True  # in which case we need all of the contaminants to the sky fiber ...

                nbad = np.sum(snr == 0)
                if nbad > 200:
                    if nbad < 2000:
                        log.warning(
                            "ignore fiber {} from {} with {} masked pixel".
                            format(otherfiber, filename, nbad))
                    must_exclude = True  # because 1 bad fiber
                    break

                if np.any(snr > 1000.):
                    log.error(
                        "signal to noise is suspiciously too high in fiber {} from {}"
                        .format(otherfiber, filename))
                    must_exclude = True  # because 1 bad fiber
                    break

                # interpolate over masked pixels or low snr pixels and shift
                medivar = np.median(frame.ivar[otherfiber])
                good = (frame.ivar[otherfiber] > 0.01 * medivar
                        )  # interpolate over bright sky lines

                # account for change of wavelength for same y coordinate
                otherfiber_central_wave = tset.wave_vs_y(otherfiber, central_y)
                flux = np.interp(
                    frame.wave +
                    (otherfiber_central_wave - skyfiber_central_wave),
                    frame.wave[good], frame.flux[otherfiber][good])
                if medflat is not None:
                    flux *= medflat[
                        otherfiber]  # apply relative transmission of fiber, i.e. undo the fiberflat correction

                if use_median_filter:
                    flux = scipy.ndimage.median_filter(
                        flux, median_filter_width, mode='constant')
                kern = kernels[np.abs(df)]
                tmp = fftconvolve(flux, kern, mode="same")
                cflux[i] = resample_flux(out_wave, frame.wave, tmp)

                # only the lower triangle of fA is filled here;
                # the upper triangle is mirrored after all frames are read
                fB[i] = skyfiberivar * cflux[i] * skyfiberflux
                for j in range(i + 1):
                    fA[i, j] = skyfiberivar * cflux[i] * cflux[j]

            if should_consider and (not must_exclude):

                scflux = np.sum(cflux, axis=0)
                mscflux = np.sum(skyfiberivar * scflux) / np.sum(skyfiberivar)
                if mscflux < 100:
                    continue

                if with_cst:
                    i = dfiber.size
                    fA[i, i] = skyfiberivar  # constant term
                    fB[i] = skyfiberivar * skyfiberflux
                    for j in range(i):
                        fA[i, j] = skyfiberivar * cflux[j]

                # just stack all wavelength to get 1 number for this fiber,
                # used only to reject sky fibers with an anomalously large signal
                scflux = np.sum(cflux[np.abs(dfiber) == 1], axis=0)
                a = np.sum(skyfiberivar * scflux**2)
                b = np.sum(skyfiberivar * scflux * skyfiberflux)
                xtalk = b / a
                err = 1. / np.sqrt(a)
                msky = np.sum(
                    skyfiberivar * skyfiberflux) / np.sum(skyfiberivar)
                ra = frame.fibermap["TARGET_RA"][skyfiber]
                dec = frame.fibermap["TARGET_DEC"][skyfiber]

                if np.abs(xtalk) > 0.02 and np.abs(xtalk) / err > 5:
                    log.warning(
                        "discard skyfiber = {}, xtalk = {:4.3f} +- {:4.3f}, ra = {:5.4f} , dec = {:5.4f}, sky fiber flux= {:4.3f}, cont= {:4.3f}"
                        .format(skyfiber, xtalk, err, ra, dec, msky, mscflux))
                    continue

                for i in range(dfiber.size):
                    ninput[skyfiberbundle,
                           i] += int(np.sum(fB[i]) != 0)  # to monitor
                B[skyfiberbundle] += fB
                A[skyfiberbundle] += fA

    if A is None:
        # no frame was read, so there is nothing to solve
        log.error("No input frame")
        raise ValueError("No input frame")

    # symmetrize: copy the lower triangle into the upper one (all bundles at once)
    for i in range(npar):
        for j in range(i):
            A[:, j, i] = A[:, i, j]

    # now solve
    crosstalk = np.zeros((nbundles, dfiber.size, out_wave.size))
    crosstalk_ivar = np.zeros((nbundles, dfiber.size, out_wave.size))
    for bundle in range(nbundles):
        nfailed = 0
        for j in range(out_wave.size):
            try:
                Ai = np.linalg.inv(A[bundle, :, :, j])
            except np.linalg.LinAlgError:
                # singular system (e.g. no usable sky fiber in this bundle):
                # leave crosstalk and its ivar at zero for this bin
                nfailed += 1
                continue
            coef = Ai.dot(B[bundle, :, j])
            var = np.diag(Ai)
            if with_cst:
                # last coefficient is the constant term
                coef = coef[:-1]
                var = var[:-1]
            crosstalk[bundle, :, j] = coef
            crosstalk_ivar[bundle, :, j] = 1. / var
        if nfailed > 0:
            log.warning(
                "bundle {}: singular system in {} of {} wavelength bins".format(
                    bundle, nfailed, out_wave.size))

    table = Table()
    table["WAVELENGTH"] = out_wave
    for bundle in range(nbundles):
        for i, df in enumerate(dfiber):
            key = "CROSSTALK-B{:02d}-F{:+d}".format(bundle, df)
            table[key] = crosstalk[bundle, i]
            key = "CROSSTALKIVAR-B{:02d}-F{:+d}".format(bundle, df)
            table[key] = crosstalk_ivar[bundle, i]
            key = "NINPUT-B{:02d}-F{:+d}".format(bundle, df)
            table[key] = np.repeat(ninput[bundle, i], out_wave.size)

    table.write(args.outfile, overwrite=True)
    log.info("wrote {}".format(args.outfile))

    log.info("number of sky fibers used per bundle:")
    for bundle in range(nbundles):
        log.info("bundle {}: {}".format(bundle, ninput[bundle]))

    if args.plot:
        for bundle in range(nbundles):
            for i, df in enumerate(dfiber):
                # avoid dividing by zero where the fit failed (ivar == 0)
                err = 1. / np.sqrt(crosstalk_ivar[bundle, i] +
                                   (crosstalk_ivar[bundle, i] == 0))
                plt.errorbar(out_wave,
                             crosstalk[bundle, i],
                             err,
                             fmt="o-",
                             label="bundle = {:02d} dfiber = {:+d}".format(
                                 bundle, df))
        plt.grid()
        plt.legend()
        plt.show()