def get_new_frame(night, expid, camera, basedir, nsky, rep=0):
    '''Write a copy of a frame file whose fibermap has been re-flagged by pick_sky_fibers.

    The input frame is located with desispec.io.findfile and its fiberflat
    with findcalibfile (using the frame header).  pick_sky_fibers() is then
    called on the frame with the requested ``nsky`` and the frame is written
    to ``basedir``.

    Args:
        night: night of the exposure, passed to desispec.io.findfile
        expid: exposure id; must be an integer because it is formatted
            with ``{:08d}`` in the output file name
        camera: camera name, e.g. b3, r4, z2 (string)
        basedir: directory the new frame file is written to
        nsky: forwarded to pick_sky_fibers(..., nsky=nsky)

    Options:
        rep: label appended to the output file name so that several
            realisations with the same nsky can coexist; default is 0

    Writes ``<basedir>/frame-<camera>-<expid>-<nsky>-<rep>.fits`` with
    desispec.io.write_frame().
    '''
    input_frame_path = desispec.io.findfile('frame', night, expid, camera=camera)
    print(input_frame_path)

    #- the fiberflat is looked up from the header of the original frame
    frame_header = fitsio.read_header(input_frame_path)
    flat_path = findcalibfile([frame_header], 'FIBERFLAT')

    frame = desispec.io.read_frame(input_frame_path)
    fiberflat = desispec.io.read_fiberflat(flat_path)
    print(flat_path)

    pick_sky_fibers(frame, fiberflat, nsky=nsky)

    #- output updated frame to basedir
    output_name = f'/frame-{camera}-{expid:08d}-{nsky}-{rep}.fits'
    desispec.io.write_frame(basedir + output_name, frame)
def run_compute_sky(night, expid, cameras, basedir, nsky_list, reps=None):
    '''Generates sky models for new frame files, using the --no-extra-variance
    option of desi_compute_sky.

    For every camera, every repetition index and every entry of nsky_list,
    runs ``desi_compute_sky`` on
    ``<basedir>/frame-<cam>-<expid>-<nsky>-<rep>.fits`` and writes
    ``<basedir>/sky-<cam>-<expid>-<nsky>-<rep>.fits``.  Prints the command
    and then ``OK`` or ``FAILED`` depending on its exit status.

    Args:
        night: night of the exposure, passed to desispec.io.findfile
        expid: exposure id; must be an integer (formatted with ``{:08d}``)
        cameras: list of cameras corresponding to frame files in basedir,
            example ['r3', 'z3', 'b3'] (list or array)
        basedir: where to look for frame files and write sky files
        nsky_list: list with the different numbers of sky fibers the frame
            files were generated with

    Options:
        reps: number of different frame files for each camera and nsky
            combination (indices 0..reps-1).  Default is 5
    '''
    if reps is None:
        reps = 5

    for cam in cameras:
        #- the fiberflat depends only on the camera, so look it up once
        framefile = desispec.io.findfile('frame', night, expid, camera=cam)
        header = fitsio.read_header(framefile)
        fiberflatfile = findcalibfile([header, ], 'FIBERFLAT')

        for N in range(reps):
            for n in nsky_list:
                newframefile = basedir + '/frame-{}-{:08d}-{}-{}.fits'.format(
                    cam, expid, n, N)
                skyfile = basedir + '/sky-{}-{:08d}-{}-{}.fits'.format(
                    cam, expid, n, N)

                #- build the argument list directly instead of splitting a
                #- formatted string, so paths containing whitespace survive
                cmd = [
                    'desi_compute_sky',
                    '-i', newframefile,
                    '--fiberflat', str(fiberflatfile),
                    '-o', skyfile,
                    '--no-extra-variance',
                ]
                print('RUNNING {}'.format(' '.join(cmd)))
                err = subprocess.call(cmd)
                if err:
                    print('FAILED')
                else:
                    print('OK')
def correct_fiber_crosstalk(frame,fiberflat=None,xyset=None):
    """Subtract the fiber cross talk estimated from neighbouring fibers.

    Modifies frame.flux and frame.ivar in place; nothing is returned.

    Args:
        frame : desispec.frame.Frame object

    Optional:
        fiberflat : desispec.fiberflat.FiberFlat object, forwarded to
            compute_contamination
        xyset : desispec.xytraceset.XYTraceSet object with trace coordinates
            to shift the spectra (read from the PSF calibration file found
            with the calibration finder when not given)
    """
    log = get_logger()
    params = read_crosstalk_parameters()

    if xyset is None:
        #- fall back to the traces stored with the PSF calibration
        xyset = read_xytraceset(findcalibfile([frame.meta,], "PSF"))

    log.info("compute kernels")
    kernels = compute_crosstalk_kernels()

    #- accumulate contamination (and its variance) from fibers at offsets +-1, +-2
    shape = frame.flux.shape
    total_cont = np.zeros(shape)
    total_var = np.zeros(shape)
    for offset in (-2, -1, 1, 2):
        log.info("F{:+d}".format(offset))
        #- kernels are indexed by the absolute fiber offset
        one_cont, one_var = compute_contamination(
            frame, offset, kernels[np.abs(offset)], params, xyset, fiberflat)
        total_cont += one_cont
        total_var += one_var

    frame.flux -= total_cont

    #- add the contamination variance; the (ivar==0) term avoids dividing by
    #- zero and the (ivar>0) factor keeps those pixels at ivar=0
    valid = frame.ivar > 0
    original_var = 1. / (frame.ivar + (frame.ivar == 0))
    frame.ivar = valid / (original_var + total_var)
def run_sky_subtraction(night, expid, cameras, basedir, nsky_list, reps=None):
    '''Runs sky subtraction with new sky models and frame files.

    For every camera, repetition index and entry of nsky_list, reads
    ``<basedir>/frame-<cam>-<expid>-<nsky>-<rep>.fits`` and the matching
    ``sky-...`` file, applies the fiberflat, subtracts the sky and writes
    ``<basedir>/sframe-<cam>-<expid>-<nsky>-<rep>.fits``.

    Args:
        night: night of the exposure, passed to desispec.io.findfile
        expid: exposure id; must be an integer (formatted with ``{:08d}``)
        cameras: list of cameras corresponding to frame and sky files in
            basedir, example ['r3', 'z3', 'b3'] (list or array)
        basedir: where to look for frame/sky files and write sframe files
        nsky_list: list with the different numbers of sky fibers the frame
            and sky files were generated with

    Options:
        reps: number of different frame/sky files for each camera and nsky
            combination (indices 0..reps-1).  Default is 5
    '''
    if reps is None:
        reps = 5

    for cam in cameras:
        #- the fiberflat depends only on the camera, so read it once
        framefile = desispec.io.findfile('frame', night, expid, camera=cam)
        header = fitsio.read_header(framefile)
        fiberflatfile = findcalibfile([header, ], 'FIBERFLAT')
        fiberflat = desispec.io.read_fiberflat(fiberflatfile)

        for N in range(reps):
            for n in nsky_list:
                suffix = '-{}-{:08d}-{}-{}.fits'.format(cam, expid, n, N)
                newframefile = basedir + '/frame' + suffix
                skyfile = basedir + '/sky' + suffix
                sframefile = basedir + '/sframe' + suffix

                sframe = desispec.io.read_frame(newframefile)
                sky = desispec.io.read_sky(skyfile)

                #- both calls modify sframe in place
                apply_fiberflat(sframe, fiberflat)
                subtract_sky(sframe, sky)

                desispec.io.write_frame(sframefile, sframe)
def calc_tsnr2_cframe(cframe):
    """
    Given cframe, calc_tsnr2 guessing frame,fiberflat,skymodel,fluxcalib to use

    The frame, sky and fluxcalib files are expected next to the cframe file,
    with 'cframe' in the file name replaced by 'frame', 'sky' and
    'fluxcalib'.  The nightly fiberflat is used when it exists, otherwise the
    default calibration fiberflat found with findcalibfile.

    Args:
        cframe: input cframe Frame object; uses cframe.filename and
            cframe.meta['NIGHT'], cframe.meta['CAMERA']

    Returns (results, alpha) from calc_tsnr2

    Raises:
        ValueError: if the frame, sky or fluxcalib file does not exist
    """
    log = get_logger()
    dirname, filename = os.path.split(cframe.filename)
    framefile = os.path.join(dirname, filename.replace('cframe', 'frame'))
    skyfile = os.path.join(dirname, filename.replace('cframe', 'sky'))
    fluxcalibfile = os.path.join(dirname, filename.replace('cframe', 'fluxcalib'))

    for testfile in (framefile, skyfile, fluxcalibfile):
        if not os.path.exists(testfile):
            #- f-string so that the message names the file that is missing
            msg = f'missing {testfile}; unable to calculate TSNR2'
            log.error(msg)
            raise ValueError(msg)

    night = cframe.meta['NIGHT']
    camera = cframe.meta['CAMERA']

    #- prefer the nightly fiberflat; fall back to the default calibration
    fiberflatfile = desispec.io.findfile('fiberflatnight', night, camera=camera)
    if not os.path.exists(fiberflatfile):
        ffname = os.path.basename(fiberflatfile)
        log.warning(f'{ffname} not found; using default calibs')
        fiberflatfile = findcalibfile([cframe.meta,], 'FIBERFLAT')

    frame = desispec.io.read_frame(framefile)
    fiberflat = desispec.io.read_fiberflat(fiberflatfile)
    skymodel = desispec.io.read_sky(skyfile)
    fluxcalib = desispec.io.read_flux_calibration(fluxcalibfile)

    return calc_tsnr2(frame, fiberflat, skymodel, fluxcalib)
def main(args, comm=None):
    """Fit the PSF of an image bundle by bundle with specex and merge the result.

    The requested fiber range is split into bundles of ``args.bundlesize``
    fibers, the bundles are distributed over the MPI ranks, each rank calls
    the specex C entry point once per bundle, and rank 0 finally merges the
    per-bundle PSF files into ``args.output_psf``.

    Args:
        args: parsed options; reads input_image, output_psf, input_psf,
            extra, specmin, nspec, bundlesize, broken_fibers, debug and
            disable_merge
        comm: optional MPI communicator; without it everything runs on a
            single process

    Raises:
        RuntimeError: if the output file name does not contain ``.fits``,
            if any bundle fit fails (raised on all ranks), or if the merge
            step reported failures
    """
    log = get_logger()

    imgfile = args.input_image
    outfile = args.output_psf

    if args.input_psf is not None:
        inpsffile = args.input_psf
    else:
        #- no input PSF given: use the default one for this image
        from desispec.calibfinder import findcalibfile
        hdr = fits.getheader(imgfile)
        inpsffile = findcalibfile([hdr, ], 'PSF')

    optarray = []
    if args.extra is not None:
        optarray = args.extra.split()

    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)

    specmax = specmin + nspec

    # Now we divide our spectra into bundles

    checkbundles = set()
    checkbundles.update(
        np.floor_divide(np.arange(specmin, specmax),
                        bundlesize * np.ones(nspec)).astype(int))
    bundles = sorted(checkbundles)
    nbundle = len(bundles)

    # first fiber and number of fibers of each bundle, clipped to
    # the requested [specmin, specmax) range
    bspecmin = {}
    bnspec = {}
    for b in bundles:
        if specmin > b * bundlesize:
            bspecmin[b] = specmin
        else:
            bspecmin[b] = b * bundlesize
        if (b + 1) * bundlesize > specmax:
            bnspec[b] = specmax - bspecmin[b]
        else:
            bnspec[b] = (b + 1) * bundlesize - bspecmin[b]

    # Now we assign bundles to processes

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    # Bundle numbers start at bundles[0], which is non-zero whenever
    # specmin >= bundlesize; offset the per-rank range accordingly so
    # that the loop below only visits bundles that actually exist.
    bundleoffset = bundles[0] if bundles else 0

    mynbundle = int(nbundle / nproc)
    leftover = nbundle % nproc
    if rank < leftover:
        mynbundle += 1
        myfirstbundle = bundleoffset + rank * mynbundle
    else:
        myfirstbundle = bundleoffset + ((mynbundle + 1) * leftover) + \
            (mynbundle * (rank - leftover))

    if rank == 0:
        # Print parameters
        log.info("specex: using {} processes".format(nproc))
        log.info("specex: input image = {}".format(imgfile))
        log.info("specex: input PSF = {}".format(inpsffile))
        log.info("specex: output = {}".format(outfile))
        log.info("specex: bundlesize = {}".format(bundlesize))
        log.info("specex: specmin = {}".format(specmin))
        log.info("specex: specmax = {}".format(specmax))
        if args.broken_fibers:
            log.info("specex: broken fibers = {}".format(args.broken_fibers))

    # get the root output file

    outpat = re.compile(r'(.*)\.fits')
    outmat = outpat.match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)

    outdir = os.path.dirname(outroot)
    if rank == 0:
        # outdir is "" when the output file has no directory component;
        # os.makedirs("") would raise, and there is nothing to create
        if outdir != "":
            if not os.path.isdir(outdir):
                os.makedirs(outdir)

    failcount = 0

    for b in range(myfirstbundle, myfirstbundle + mynbundle):
        outbundle = "{}_{:02d}".format(outroot, b)
        outbundlefits = "{}.fits".format(outbundle)
        com = ['desi_psf_fit']
        com.extend(['-a', imgfile])
        com.extend(['--in-psf', inpsffile])
        com.extend(['--out-psf', outbundlefits])
        com.extend(['--first-bundle', "{}".format(b)])
        com.extend(['--last-bundle', "{}".format(b)])
        com.extend(['--first-fiber', "{}".format(bspecmin[b])])
        com.extend(['--last-fiber', "{}".format(bspecmin[b] + bnspec[b] - 1)])
        if args.broken_fibers:
            com.extend(['--broken-fibers', "{}".format(args.broken_fibers)])
        if args.debug:
            com.extend(['--debug'])

        com.extend(optarray)

        log.debug("proc {} calling {}".format(rank, " ".join(com)))

        # Build a C-style (argc, argv) pair.  arg_buffers must stay
        # referenced until the call returns since arg_pointers only
        # holds raw addresses into them.
        argc = len(com)
        arg_buffers = [ct.create_string_buffer(com[i].encode('ascii'))
                       for i in range(argc)]
        addrlist = [ct.cast(x, ct.POINTER(ct.c_char))
                    for x in map(ct.addressof, arg_buffers)]
        arg_pointers = (ct.POINTER(ct.c_char) * argc)(*addrlist)

        retval = libspecex.cspecex_desi_psf_fit(argc, arg_pointers)

        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                      "value {} running {}".format(rank, retval, comstr))
            failcount += 1

    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")

    if rank == 0:
        outfits = "{}.fits".format(outroot)

        inputs = ["{}_{:02d}.fits".format(outroot, x) for x in bundles]

        if args.disable_merge:
            log.info("don't merge")
        else:
            #- Empirically it appears that files written by one rank sometimes
            #- aren't fully buffer-flushed and closed before getting here,
            #- despite the MPI allreduce barrier.  Pause to let I/O catch up.
            log.info('HACK: taking a 20 sec pause before merging')
            sys.stdout.flush()
            time.sleep(20.)

            merge_psf(inputs, outfits)

            log.info('done merging')

            if failcount == 0:
                # only remove the per-bundle files if the merge was good
                for f in inputs:
                    if os.path.isfile(f):
                        os.remove(f)

    if comm is not None:
        failcount = comm.bcast(failcount, root=0)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("merging of per-bundle files failed")

    return
def calc_tsnr2(frame, fiberflat, skymodel, fluxcalib, alpha_only=False) :
    '''
    Compute template SNR^2 values for a given frame

    Args:
        frame : uncalibrated Frame object for one camera
        fiberflat : FiberFlat object
        skymodel : SkyModel object
        fluxcalib : FluxCalib object

    Options:
        alpha_only : if True, return right after alpha has been computed,
            with an empty tsnr2 dictionary

    returns (tsnr2, alpha):
        `tsnr2` dictionary, with keys labeling tracer and band
        (TSNR2_<TRACER>_<BAND>), of values holding nfiber length array of the
        tsnr^2 values for this camera, and `alpha`, the relative weighting
        btwn rdnoise & sky terms to model var.

    Raises:
        RuntimeError: if frame.meta["BUNIT"] is neither count/Angstrom nor
            electron/Angstrom, or if $DESIMODEL is not set

    Note:  Assumes DESIMODEL is set and up to date.
    '''
    global _camera_nea_angperpix
    global _band_ensemble

    log=get_logger()

    if not (frame.meta["BUNIT"]=="count/Angstrom" or frame.meta["BUNIT"]=="electron/Angstrom" ) :
        log.error("requires an uncalibrated frame")
        raise RuntimeError("requires an uncalibrated frame")

    camera=frame.meta["CAMERA"].strip().lower()
    band=camera[0]

    psfpath=findcalibfile([frame.meta],"PSF")
    psf=GaussHermitePSF(psfpath)

    if "DESIMODEL" not in os.environ :
        msg = "requires $DESIMODEL to get the NEA and the SNR templates"
        log.error(msg)
        raise RuntimeError(msg)

    # NEA and angstrom-per-pixel models are cached per camera at module level;
    # both are callables evaluated at (fiber, wave) below.
    if _camera_nea_angperpix is None:
        _camera_nea_angperpix = dict()

    if camera in _camera_nea_angperpix:
        nea, angperpix = _camera_nea_angperpix[camera]
    else:
        neafilename=os.path.join(os.environ["DESIMODEL"],
                                 f"data/specpsf/nea/masternea_{camera}.fits")
        log.info("read NEA file {}".format(neafilename))
        nea, angperpix = read_nea(neafilename)
        _camera_nea_angperpix[camera] = nea, angperpix

    # Template ensembles are cached per band at module level.
    if _band_ensemble is None:
        _band_ensemble = dict()

    if band in _band_ensemble:
        ensemble = _band_ensemble[band]
    else:
        ensembledir=os.path.join(os.environ["DESIMODEL"],"data/tsnr")
        log.info("read TSNR ensemble files in {}".format(ensembledir))
        ensemble = get_ensemble(ensembledir, bands=[band,])
        _band_ensemble[band] = ensemble

    nspec, nwave = fluxcalib.calib.shape

    fibers = np.arange(nspec)
    rdnoise = fb_rdnoise(fibers, frame, psf)

    ebv = frame.fibermap['EBV']

    if np.sum(ebv!=0)>0 :
        log.info("TSNR MEDIAN EBV = {:.3f}".format(np.median(ebv[ebv!=0])))
    else :
        log.info("TSNR MEDIAN EBV = 0")

    # Evaluate.
    npix = nea(fibers, frame.wave)
    angperpix = angperpix(fibers, frame.wave)
    angperspecbin = np.mean(np.gradient(frame.wave))

    for label, x in zip(['RDNOISE', 'NEA', 'ANGPERPIX', 'ANGPERSPECBIN'],
                        [rdnoise, npix, angperpix, angperspecbin]):
        log.info('{} \t {:.3f} +- {:.3f}'.format(label.ljust(10), np.median(x), np.std(x)))

    # Relative weighting between rdnoise & sky terms to model var.
    alpha = calc_alpha(frame, fibermap=frame.fibermap,
                       rdnoise_sigma=rdnoise, npix_1d=npix,
                       angperpix=angperpix, angperspecbin=angperspecbin,
                       fiberflat=fiberflat, skymodel=skymodel)

    log.info(f"TSNR ALPHA = {alpha:.6f}")

    if alpha_only:
        return {}, alpha

    # 1 for usable pixels, 0 for masked or zero-ivar pixels.
    # Builtin float: the np.float alias was removed in NumPy 1.24.
    maskfactor = np.ones_like(frame.mask, dtype=float)
    maskfactor[frame.mask > 0] = 0.0
    maskfactor *= (frame.ivar > 0.0)

    tsnrs = {}

    denom = var_model(rdnoise, npix, angperpix, angperspecbin, fiberflat, skymodel, alpha=alpha)

    for tracer in ensemble.keys():
        wave = ensemble[tracer].wave[band]
        dflux = ensemble[tracer].flux[band]

        if len(frame.wave) != len(wave) or not np.allclose(frame.wave, wave):
            log.warning(f'Resampling {tracer} ensemble wavelength to match input {camera} frame')
            tmp = np.zeros([dflux.shape[0], len(frame.wave)])
            for i in range(dflux.shape[0]):
                # extend with the edge values outside the template range
                tmp[i] = np.interp(frame.wave, wave, dflux[i],
                                   left=dflux[i,0], right=dflux[i,-1])
            dflux = tmp
            wave = frame.wave

        # Work in uncalibrated flux units (electrons per angstrom); flux_calib includes exptime. tau.
        # Broadcast.
        dflux = dflux * fluxcalib.calib # [e/A]

        # Wavelength dependent fiber flat;  Multiply or divide - check with Julien.
        result = dflux * fiberflat.fiberflat

        # Apply dust transmission.
        result *= dust_transmission(frame.wave, ebv)

        result = result**2.

        result /= denom

        # Eqn. (1) of https://desi.lbl.gov/DocDB/cgi-bin/private/RetrieveFile?docid=4723;filename=sky-monitor-mc-study-v1.pdf;version=2
        tsnrs[tracer] = np.sum(result * maskfactor, axis=1)

    results=dict()
    for tracer in tsnrs.keys():
        key = 'TSNR2_{}_{}'.format(tracer.upper(), band.upper())
        results[key]=tsnrs[tracer]
        log.info('{} = {:.6f}'.format(key, np.median(tsnrs[tracer])))

    return results, alpha
def main(args, comm=None):
    """Run the specex PSF fit one bundle at a time and merge the outputs.

    The fiber range [specmin, specmin+nspec) is cut into bundles of
    ``args.bundlesize`` fibers; bundles are shared between the MPI ranks,
    each rank calls run_specex once per bundle, and rank 0 merges the
    per-bundle files into ``args.output_psf`` unless args.disable_merge.

    Args:
        args: parsed options; reads input_image, output_psf, input_psf,
            extra, specmin, nspec, bundlesize, broken_fibers, debug and
            disable_merge
        comm: optional MPI communicator

    Raises:
        RuntimeError: output name without ``.fits``, any failed bundle fit
            (on all ranks), or a failed merge
    """
    log = get_logger()

    #- only import when running, to avoid requiring specex install for import
    from specex.specex import run_specex

    imgfile = args.input_image
    outfile = args.output_psf

    if comm is None:
        nproc, rank = 1, 0
    else:
        nproc, rank = comm.size, comm.rank

    #- rank 0 reads the image header, everybody else receives it
    hdr = fits.getheader(imgfile) if rank == 0 else None
    if comm is not None:
        hdr = comm.bcast(hdr, root=0)

    #- Locate line list in $SPECEXDATA or specex/data
    specexdata = os.environ.get('SPECEXDATA')
    if specexdata is None:
        from pkg_resources import resource_filename
        specexdata = resource_filename('specex', 'data')
    lamp_lines_file = os.path.join(specexdata, 'specex_linelist_desi.txt')

    inpsffile = args.input_psf
    if inpsffile is None:
        from desispec.calibfinder import findcalibfile
        inpsffile = findcalibfile([hdr, ], 'PSF')

    optarray = args.extra.split() if args.extra is not None else []

    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)
    specmax = specmin + nspec

    #- bundle number of every requested fiber, then the distinct bundles
    fiber_bundle = np.floor_divide(np.arange(specmin, specmax),
                                   bundlesize * np.ones(nspec)).astype(int)
    bundles = sorted(set(fiber_bundle))
    nbundle = len(bundles)

    #- first fiber and fiber count per bundle, clipped to the requested range
    bspecmin = {}
    bnspec = {}
    for bundle in bundles:
        first_fiber = max(bundle * bundlesize, specmin)
        end_fiber = min((bundle + 1) * bundlesize, specmax)
        bspecmin[bundle] = first_fiber
        bnspec[bundle] = end_fiber - first_fiber

    #- the first `leftover` ranks take one extra bundle each
    share, leftover = divmod(nbundle, nproc)
    if rank < leftover:
        mynbundle = share + 1
        myfirstbundle = bundles[0] + rank * mynbundle
    else:
        mynbundle = share
        myfirstbundle = bundles[0] + (share + 1) * leftover + share * (rank - leftover)

    if rank == 0:
        # Print parameters
        parameters = (
            ("specex: using {} processes", nproc),
            ("specex: input image = {}", imgfile),
            ("specex: input PSF = {}", inpsffile),
            ("specex: output = {}", outfile),
            ("specex: bundlesize = {}", bundlesize),
            ("specex: specmin = {}", specmin),
            ("specex: specmax = {}", specmax),
        )
        for template, value in parameters:
            log.info(template.format(value))
        if args.broken_fibers:
            log.info("specex: broken fibers = {}".format(args.broken_fibers))

    # get the root output file
    outmat = re.compile(r'(.*)\.fits').match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)

    #- only rank 0 creates the output directory (if there is one to create)
    outdir = os.path.dirname(outroot)
    if rank == 0 and outdir != "" and not os.path.isdir(outdir):
        os.makedirs(outdir)

    band = hdr["camera"].lower().strip()[0]

    failcount = 0
    for b in range(myfirstbundle, myfirstbundle + mynbundle):
        outbundlefits = "{}.fits".format("{}_{:02d}".format(outroot, b))
        com = [
            'desi_psf_fit',
            '-a', imgfile,
            '--in-psf', inpsffile,
            '--out-psf', outbundlefits,
            '--lamp-lines', lamp_lines_file,
            '--first-bundle', "{}".format(b),
            '--last-bundle', "{}".format(b),
            '--first-fiber', "{}".format(bspecmin[b]),
            '--last-fiber', "{}".format(bspecmin[b] + bnspec[b] - 1),
        ]
        #- z cameras use a higher wavelength degree plus a continuum fit
        if band == "z":
            com += ['--legendre-deg-wave', "{}".format(3), '--fit-continuum']
        else:
            com += ['--legendre-deg-wave', "{}".format(1)]
        if args.broken_fibers:
            com += ['--broken-fibers', "{}".format(args.broken_fibers)]
        if args.debug:
            com.append('--debug')
        com += optarray

        log.debug("proc {} calling {}".format(rank, " ".join(com)))

        retval = run_specex(com)
        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                      "value {} running {}".format(rank, retval, comstr))
            failcount += 1

    #- every rank learns the total number of failed bundles
    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")

    if rank == 0:
        outfits = "{}.fits".format(outroot)
        inputs = ["{}_{:02d}.fits".format(outroot, x) for x in bundles]

        if not args.disable_merge:
            #- Empirically it appears that files written by one rank sometimes
            #- aren't fully buffer-flushed and closed before getting here,
            #- despite the MPI allreduce barrier.  Pause to let I/O catch up.
            log.info('5 sec pause before merging')
            sys.stdout.flush()
            time.sleep(5.)

            merge_psf(inputs, outfits)
            log.info('done merging')

            if failcount == 0:
                # only remove the per-bundle files if the merge was good
                for bundlefile in inputs:
                    if os.path.isfile(bundlefile):
                        os.remove(bundlefile)
        else:
            log.info("don't merge")

    if comm is not None:
        failcount = comm.bcast(failcount, root=0)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("merging of per-bundle files failed")
def main(args=None, comm=None):
    """Run the joint-fit step for a set of exposures of one night.

    Depending on ``args.obstype`` this
      * ARC: averages the per-exposure PSFs into a psfnight file per camera
        (rank 0 only, via desispec.scripts.specex.mean_psf),
      * FLAT: averages fiberflats per camera and lamp and then runs
        desi_autocalib_fiberflat per camera arm (rank 0 only),
      * SCIENCE: runs desi_fit_stdstars once per spectrograph, spread over
        the MPI ranks, and symlinks the result for the other exposures.

    Args:
        args: None (options are obtained from parse()), a list/tuple of
            command line arguments (given to parse()), or an already parsed
            options object
        comm: optional MPI communicator; when None, assign_mpi() decides
            based on args.mpi and args.batch

    Raises:
        RuntimeError: neither --inputs nor (--night and --expids) given, or
            the obstype is not one of SCIENCE, ARC, FLAT
        IOError: a raw input file or a requested calibnight fiberflat is
            missing

    Calls sys.exit() after submitting a batch job when args.batch is set,
    and sys.exit(1) when every command of a step failed.
    """
    if args is None:
        args = parse()
    elif isinstance(args, (list, tuple)):
        args = parse(args)

    log = get_logger()

    start_mpi_connect = time.time()
    if comm is not None:
        #- Use the provided comm to determine rank and size
        rank = comm.rank
        size = comm.size
    else:
        #- Check MPI flags and determine the comm, rank, and size given the arguments
        comm, rank, size = assign_mpi(do_mpi=args.mpi, do_batch=args.batch, log=log)
    stop_mpi_connect = time.time()

    #- Start timer; only print log messages from rank 0 (others are silent)
    timer = desiutil.timer.Timer(silent=(rank > 0))

    #- Fill in timing information for steps before we had the timer created
    #- NOTE(review): start_imports / stop_imports are not defined in this
    #- function; presumably module-level timestamps -- confirm
    if args.starttime is not None:
        timer.start('startup', starttime=args.starttime)
        timer.stop('startup', stoptime=start_imports)

    timer.start('imports', starttime=start_imports)
    timer.stop('imports', stoptime=stop_imports)

    timer.start('mpi_connect', starttime=start_mpi_connect)
    timer.stop('mpi_connect', stoptime=stop_mpi_connect)

    #- Freeze IERS after parsing args so that it doesn't bother if only --help
    timer.start('freeze_iers')
    desiutil.iers.freeze_iers()
    timer.stop('freeze_iers')

    #- Preflight checks
    timer.start('preflight')

    # - Preflight checks
    if rank > 0:
        # - Let rank 0 fetch these, and then broadcast
        args, hdr, camhdr = None, None, None
    else:
        if args.inputs is None:
            if args.night is None or args.expids is None:
                raise RuntimeError(
                    'Must specify --inputs or --night AND --expids')
            else:
                #- expids arrive as a comma separated string
                args.expids = np.array(
                    args.expids.strip(' \t').split(',')).astype(int)
                args.inputs = []
                for expid in args.expids:
                    infile = findfile('raw', night=args.night, expid=expid)
                    args.inputs.append(infile)
                    if not os.path.isfile(infile):
                        raise IOError('Missing input file: {}'.format(infile))
        else:
            #- inputs arrive as a comma separated string of file names
            args.inputs = np.array(args.inputs.strip(' \t').split(','))
            #- args.night will be defined in update_args_with_headers,
            #- but let's define the expids here
            #- NOTE: inputs has priority. Overwriting expids if they existed.
            args.expids = []
            for infile in args.inputs:
                hdr = load_raw_data_header(pathname=infile, return_filehandle=False)
                args.expids.append(int(hdr['EXPID']))

        #- the first (sorted) exposure is the reference one used below
        args.expids = np.sort(args.expids)
        args.inputs = np.sort(args.inputs)
        args.expid = args.expids[0]
        args.input = args.inputs[0]

        #- Use header information to fill in missing information in the arguments object
        args, hdr, camhdr = update_args_with_headers(args)

        #- If not a science observation, we don't need the hdr or camhdr objects,
        #- So let's not broadcast them to all the ranks
        if args.obstype != 'SCIENCE':
            hdr, camhdr = None, None

    if comm is not None:
        args = comm.bcast(args, root=0)
        hdr = comm.bcast(hdr, root=0)
        camhdr = comm.bcast(camhdr, root=0)

    known_obstype = ['SCIENCE', 'ARC', 'FLAT']
    if args.obstype not in known_obstype:
        raise RuntimeError('obstype {} not in {}'.format(
            args.obstype, known_obstype))

    timer.stop('preflight')

    # -------------------------------------------------------------------------
    # - Create and submit a batch job if requested

    if args.batch:
        #camword = create_camword(args.cameras)
        #exp_str = '-'.join('{:08d}'.format(expid) for expid in args.expids)
        if args.obstype.lower() == 'science':
            jobdesc = 'stdstarfit'
        elif args.obstype.lower() == 'arc':
            jobdesc = 'psfnight'
        elif args.obstype.lower() == 'flat':
            jobdesc = 'nightlyflat'
        else:
            jobdesc = args.obstype.lower()
        scriptfile = create_desi_proc_batch_script(night=args.night, exp=args.expids, cameras=args.cameras,\
                                                jobdesc=jobdesc, queue=args.queue, runtime=args.runtime,\
                                                batch_opts=args.batch_opts, timingfile=args.timingfile, system_name=args.system_name)
        err = 0
        if not args.nosubmit:
            err = subprocess.call(['sbatch', scriptfile])
        #- batch mode never runs the processing itself
        sys.exit(err)

    # -------------------------------------------------------------------------
    # - Proceed with running

    # - What are we going to do?
    if rank == 0:
        log.info('----------')
        log.info('Input {}'.format(args.inputs))
        log.info('Night {} expids {}'.format(args.night, args.expids))
        log.info('Obstype {}'.format(args.obstype))
        log.info('Cameras {}'.format(args.cameras))
        log.info('Output root {}'.format(desispec.io.specprod_root()))
        log.info('----------')

    # - Wait for rank 0 to make directories before proceeding
    if comm is not None:
        comm.barrier()

    # -------------------------------------------------------------------------
    # - Merge PSF of night if applicable

    if args.obstype in ['ARC']:
        timer.start('psfnight')
        num_cmd = num_err = 0
        if rank == 0:
            for camera in args.cameras:
                psfnightfile = findfile('psfnight', args.night, args.expids[0], camera)
                if not os.path.isfile( psfnightfile ):  # we still don't have a psf night, see if we can compute it ...
                    psfs = [ findfile('psf', args.night, expid, camera).replace("psf", "fit-psf") for expid in args.expids ]
                    log.info(
                        "Number of PSF for night={} camera={} = {}".format(
                            args.night, camera, len(psfs)))
                    #- NOTE(review): the test requires at least 5 psfs while
                    #- the message below says "Fewer than 4" -- confirm which
                    #- threshold is intended
                    if len(psfs) > 4:  # lets do it!
                        log.info("Computing psfnight ...")
                        dirname = os.path.dirname(psfnightfile)
                        if not os.path.isdir(dirname):
                            os.makedirs(dirname)
                        num_cmd += 1
                        #- generic try/except so that any failure doesn't leave
                        #- MPI rank 0 hanging while others are waiting for it
                        try:
                            desispec.scripts.specex.mean_psf(
                                psfs, psfnightfile)
                        except:
                            log.error('specex.meanpsf failed for {}'.format(
                                os.path.basename(psfnightfile)))
                            exc_type, exc_value, exc_traceback = sys.exc_info()
                            lines = traceback.format_exception(
                                exc_type, exc_value, exc_traceback)
                            log.error(''.join(lines))
                            sys.stdout.flush()

                        #- success is judged by the output file existing
                        if not os.path.exists(psfnightfile):
                            log.error(f'Failed to create {psfnightfile}')
                            num_err += 1
                    else:
                        log.info(
                            "Fewer than 4 psfs were provided, can't compute psfnight. Exiting ..."
                        )
                        #- counted as a failed command
                        num_cmd += 1
                        num_err += 1

        timer.stop('psfnight')

        num_cmd, num_err = mpi_count_failures(num_cmd, num_err, comm=comm)
        if rank == 0:
            if num_err > 0:
                log.error(f'{num_err}/{num_cmd} psfnight commands failed')

        if num_err > 0 and num_err == num_cmd:
            sys.stdout.flush()
            if rank == 0:
                log.critical('All psfnight commands failed')
            sys.exit(1)

    # -------------------------------------------------------------------------
    # - Average and auto-calib fiberflats of night if applicable

    if args.obstype in ['FLAT']:
        timer.start('fiberflatnight')
        #- Track number of commands run and number of errors for exit code
        num_cmd = 0
        num_err = 0

        if rank == 0:
            fiberflatnightfile = findfile('fiberflatnight', args.night, args.expids[0], args.cameras[0])
            fiberflatdirname = os.path.dirname(fiberflatnightfile)
            if os.path.isfile(fiberflatnightfile):
                log.info("Fiberflatnight already exists. Exitting ...")
            elif len( args.cameras ) < 6:  # we still don't have them, see if we can compute them
                # , but need at least 2 spectros ...
                log.info(
                    "Fewer than 6 cameras were available, so couldn't perform joint fit. Exiting ..."
                )
            else:
                flats = []
                for camera in args.cameras:
                    for expid in args.expids:
                        flats.append(
                            findfile('fiberflat', args.night, expid, camera))
                log.info("Number of fiberflat for night {} = {}".format(
                    args.night, len(flats)))
                if len(flats) < 3 * 4 * len(args.cameras):
                    log.info(
                        "Fewer than 3 exposures with 4 lamps were available. Can't perform joint fit. Exiting..."
                    )
                else:
                    log.info(
                        "Computing fiberflatnight per lamp and camera ...")
                    tmpdir = os.path.join(fiberflatdirname, "tmp")
                    if not os.path.isdir(tmpdir):
                        os.makedirs(tmpdir)

                    log.info(
                        "First average measurements per camera and per lamp")
                    average_flats = dict()
                    for camera in args.cameras:
                        # list of flats for this camera
                        #- selected by the camera name appearing in the path
                        flats_for_this_camera = []
                        for flat in flats:
                            if flat.find(camera) >= 0:
                                flats_for_this_camera.append(flat)
                        # log.info("For camera {} , flats = {}".format(camera,flats_for_this_camera))
                        # sys.exit(12)

                        # average per lamp (and camera)
                        average_flats[camera] = list()
                        for lampbox in range(4):
                            ofile = os.path.join(
                                tmpdir,
                                "fiberflatnight-camera-{}-lamp-{}.fits".format(
                                    camera, lampbox))
                            if not os.path.isfile(ofile):
                                log.info(
                                    "Average flat for camera {} and lamp box #{}"
                                    .format(camera, lampbox))
                                pg = "CALIB DESI-CALIB-0{} LEDs only".format(
                                    lampbox)

                                cmd = "desi_average_fiberflat --program '{}' --outfile {} -i ".format(
                                    pg, ofile)
                                for flat in flats_for_this_camera:
                                    cmd += " {} ".format(flat)
                                num_cmd += 1
                                err = runcmd(cmd,
                                             inputs=flats_for_this_camera,
                                             outputs=[ ofile, ])
                                if err:
                                    num_err += 1
                                if os.path.isfile(ofile):
                                    average_flats[camera].append(ofile)
                                else:
                                    log.error(
                                        f"Generating {ofile} failed; proceeding with other flats"
                                    )
                            else:
                                log.info("Will use existing {}".format(ofile))
                                average_flats[camera].append(ofile)

                    log.info(
                        "Auto-calibration across lamps and spectro per camera arm (b,r,z)"
                    )
                    for camera_arm in ["b", "r", "z"]:
                        cameras_for_this_arm = []
                        flats_for_this_arm = []
                        for camera in args.cameras:
                            if camera[0].lower() == camera_arm:
                                cameras_for_this_arm.append(camera)
                                if camera in average_flats:
                                    for flat in average_flats[camera]:
                                        flats_for_this_arm.append(flat)
                        if len(flats_for_this_arm) > 0:
                            cmd = "desi_autocalib_fiberflat --night {} --arm {} -i ".format(
                                args.night, camera_arm)
                            for flat in flats_for_this_arm:
                                cmd += " {} ".format(flat)
                            num_cmd += 1
                            err = runcmd(cmd,
                                         inputs=flats_for_this_arm,
                                         outputs=[])
                            if err:
                                num_err += 1
                        else:
                            log.error(f'No flats found for arm {camera_arm}')
                    log.info("Done with fiber flats per night")

        timer.stop('fiberflatnight')
        num_cmd, num_err = mpi_count_failures(num_cmd, num_err, comm=comm)
        if comm is not None:
            comm.barrier()

        if rank == 0:
            if num_err > 0:
                log.error(f'{num_err}/{num_cmd} fiberflat commands failed')

        if num_err > 0 and num_err == num_cmd:
            if rank == 0:
                log.critical('All fiberflat commands failed')
            sys.exit(1)

    ##################### Note #############################
    ### Still for single exposure. Needs to be re-factored #
    ########################################################

    if args.obstype in ['SCIENCE']:
        #inputfile = findfile('raw', night=args.night, expid=args.expids[0])
        #if not os.path.isfile(inputfile):
        #    raise IOError('Missing input file: {}'.format(inputfile))
        ## - Fill in values from raw data header if not overridden by command line
        #fx = fitsio.FITS(inputfile)
        #if 'SPEC' in fx:  # - 20200225 onwards
        #    # hdr = fits.getheader(args.input, 'SPEC')
        #    hdr = fx['SPEC'].read_header()
        #elif 'SPS' in fx:  # - 20200224 and before
        #    # hdr = fits.getheader(args.input, 'SPS')
        #    hdr = fx['SPS'].read_header()
        #else:
        #    # hdr = fits.getheader(args.input, 0)
        #    hdr = fx[0].read_header()
        #
        #camhdr = dict()
        #for cam in args.cameras:
        #    camhdr[cam] = fx[cam].read_header()
        #fx.close()

        timer.start('stdstarfit')
        num_err = num_cmd = 0
        if rank == 0:
            log.info('Starting stdstar fitting at {}'.format(time.asctime()))

        # -------------------------------------------------------------------------
        # - Get input fiberflat
        #- order of preference: --fiberflat, fiberflatnight of --calibnight,
        #- fiberflatnight of this night, most recent one (if requested),
        #- default calibration
        input_fiberflat = dict()
        if rank == 0:
            for camera in args.cameras:
                if args.fiberflat is not None:
                    input_fiberflat[camera] = args.fiberflat
                elif args.calibnight is not None:
                    # look for a fiberflatnight for this calib night
                    fiberflatnightfile = findfile('fiberflatnight',
                                                  args.calibnight,
                                                  args.expids[0], camera)
                    if not os.path.isfile(fiberflatnightfile):
                        log.error("no {}".format(fiberflatnightfile))
                        raise IOError("no {}".format(fiberflatnightfile))
                    input_fiberflat[camera] = fiberflatnightfile
                else:
                    # look for a fiberflatnight fiberflat
                    fiberflatnightfile = findfile('fiberflatnight', args.night,
                                                  args.expids[0], camera)
                    if os.path.isfile(fiberflatnightfile):
                        input_fiberflat[camera] = fiberflatnightfile
                    elif args.most_recent_calib:
                        # -- NOTE: Finding most recent only with respect to the first night
                        nightfile = find_most_recent(args.night,
                                                     file_type='fiberflatnight')
                        if nightfile is None:
                            input_fiberflat[camera] = findcalibfile(
                                [hdr, camhdr[camera]], 'FIBERFLAT')
                        else:
                            input_fiberflat[camera] = nightfile
                    else:
                        input_fiberflat[camera] = findcalibfile(
                            [hdr, camhdr[camera]], 'FIBERFLAT')
                log.info("Will use input FIBERFLAT: {}".format(
                    input_fiberflat[camera]))

        if comm is not None:
            input_fiberflat = comm.bcast(input_fiberflat, root=0)

        # - Group inputs by spectrograph
        framefiles = dict()
        skyfiles = dict()
        fiberflatfiles = dict()
        for camera in args.cameras:
            #- second character of the camera name is the spectrograph number
            sp = int(camera[1])
            if sp not in framefiles:
                framefiles[sp] = list()
                skyfiles[sp] = list()
                fiberflatfiles[sp] = list()

            fiberflatfiles[sp].append(input_fiberflat[camera])
            for expid in args.expids:
                framefiles[sp].append(
                    findfile('frame', args.night, expid, camera))
                skyfiles[sp].append(findfile('sky', args.night, expid, camera))

        # - Hardcoded stdstar model version
        starmodels = os.path.join(os.getenv('DESI_BASIS_TEMPLATES'),
                                  'stdstar_templates_v2.2.fits')

        # - Fit stdstars per spectrograph (not per-camera)
        spectro_nums = sorted(framefiles.keys())
        ## for sp in spectro_nums[rank::size]:
        for i in range(rank, len(spectro_nums), size):
            sp = spectro_nums[i]
            # - NOTE: Saving the joint fit file with only the name of the first exposure
            stdfile = findfile('stdstars',
                               args.night,
                               args.expids[0],
                               spectrograph=sp)
            #stdfile.replace('{:08d}'.format(args.expids[0]),'-'.join(['{:08d}'.format(eid) for eid in args.expids]))
            cmd = "desi_fit_stdstars"
            cmd += " --delta-color 0.1"
            cmd += " --frames {}".format(' '.join(framefiles[sp]))
            cmd += " --skymodels {}".format(' '.join(skyfiles[sp]))
            cmd += " --fiberflats {}".format(' '.join(fiberflatfiles[sp]))
            cmd += " --starmodels {}".format(starmodels)
            cmd += " --outfile {}".format(stdfile)
            if args.maxstdstars is not None:
                cmd += " --maxstdstars {}".format(args.maxstdstars)

            inputs = framefiles[sp] + skyfiles[sp] + fiberflatfiles[sp]
            num_cmd += 1
            err = runcmd(cmd, inputs=inputs, outputs=[stdfile])
            if err:
                num_err += 1

        timer.stop('stdstarfit')
        num_cmd, num_err = mpi_count_failures(num_cmd, num_err, comm=comm)
        if comm is not None:
            comm.barrier()

        if rank == 0 and num_err > 0:
            log.error(f'{num_err}/{num_cmd} stdstar commands failed')

        sys.stdout.flush()
        if num_err > 0 and num_err == num_cmd:
            if rank == 0:
                log.critical('All stdstar commands failed')
            sys.exit(1)

        #- make the joint fit visible under the name of every other exposure
        #- through relative symlinks to the file of the first exposure
        if rank == 0 and len(args.expids) > 1:
            for sp in spectro_nums:
                saved_stdfile = findfile('stdstars',
                                         args.night,
                                         args.expids[0],
                                         spectrograph=sp)
                for expid in args.expids[1:]:
                    new_stdfile = findfile('stdstars',
                                           args.night,
                                           expid,
                                           spectrograph=sp)
                    new_dirname, new_fname = os.path.split(new_stdfile)
                    log.debug(
                        "Path exists: {}, file exists: {}, link exists: {}".
                        format(os.path.exists(new_stdfile),
                               os.path.isfile(new_stdfile),
                               os.path.islink(new_stdfile)))
                    relpath_saved_std = os.path.relpath(
                        saved_stdfile, new_dirname)
                    log.debug(f'Sym Linking jointly fitted stdstar file: {new_stdfile} '+\
                              f'to existing file at rel. path {relpath_saved_std}')
                    runcmd(os.symlink, args=(relpath_saved_std, new_stdfile), \
                           inputs=[saved_stdfile, ], outputs=[new_stdfile, ])
                    log.debug(
                        "Path exists: {}, file exists: {}, link exists: {}".
                        format(os.path.exists(new_stdfile),
                               os.path.isfile(new_stdfile),
                               os.path.islink(new_stdfile)))

    # -------------------------------------------------------------------------
    # - Wrap up

    # if rank == 0:
    #     report = timer.report()
    #     log.info('Rank 0 timing report:\n' + report)

    if comm is not None:
        timers = comm.gather(timer, root=0)
    else:
        timers = [ timer, ]

    if rank == 0:
        stats = desiutil.timer.compute_stats(timers)
        log.info('Timing summary statistics:\n' + json.dumps(stats, indent=2))

        if args.timingfile:
            if os.path.exists(args.timingfile):
                with open(args.timingfile) as fx:
                    previous_stats = json.load(fx)

                #- augment previous_stats with new entries, but don't overwrite old
                for name in stats:
                    if name not in previous_stats:
                        previous_stats[name] = stats[name]

                stats = previous_stats

            #- write to a temporary file, then rename it over the target
            tmpfile = args.timingfile + '.tmp'
            with open(tmpfile, 'w') as fx:
                json.dump(stats, fx, indent=2)
            os.rename(tmpfile, args.timingfile)

    if rank == 0:
        log.info('All done at {}'.format(time.asctime()))
def main(args, comm=None):
    """Fit a PSF with specex one fiber bundle at a time, then merge the results.

    The fiber range ``[specmin, specmin + nspec)`` is split into bundles of
    ``bundlesize`` fibers.  The bundles are shared out over the ranks of
    ``comm`` (everything runs on one process when ``comm`` is None); each
    bundle is fit by ``libspecex.cspecex_desi_psf_fit`` and written to
    ``<outroot>_<bundle>.fits``.  Rank 0 then combines the per-bundle files
    into ``args.output_psf`` with ``merge_psf`` and removes them.

    Args:
        args: parsed options.  This function reads ``input_image``,
            ``output_psf``, ``input_psf``, ``extra``, ``specmin``, ``nspec``,
            ``bundlesize`` and ``debug``.
        comm: optional MPI communicator used to distribute the bundles.

    Raises:
        RuntimeError: if ``args.output_psf`` does not contain ``.fits``, if
            any bundle fit returns a non-zero status (raised on every rank),
            or, on ranks other than 0, if the merge failed on rank 0.
        Exception: on rank 0 a failure inside ``merge_psf`` is re-raised
            unchanged, after the other ranks have been told about it.
    """
    log = get_logger()

    imgfile = args.input_image
    outfile = args.output_psf

    if args.input_psf is not None:
        inpsffile = args.input_psf
    else:
        # No input PSF given: look one up from the image header.
        from desispec.calibfinder import findcalibfile
        hdr = fits.getheader(imgfile)
        inpsffile = findcalibfile([hdr,], 'PSF')

    optarray = []
    if args.extra is not None:
        optarray = args.extra.split()

    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)

    specmax = specmin + nspec

    # Now we divide our spectra into bundles.  The bundle numbers touched by
    # [specmin, specmax) form a contiguous range.
    if nspec > 0:
        bundles = list(range(specmin // bundlesize,
                             (specmax - 1) // bundlesize + 1))
    else:
        bundles = []
    nbundle = len(bundles)

    # First fiber and fiber count of each bundle, clipped to the request.
    bspecmin = {}
    bnspec = {}
    for b in bundles:
        bspecmin[b] = max(specmin, b * bundlesize)
        bnspec[b] = min(specmax, (b + 1) * bundlesize) - bspecmin[b]

    # Now we assign bundles to processes

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    # The first `leftover` ranks each take one extra bundle.
    mynbundle = nbundle // nproc
    leftover = nbundle % nproc
    if rank < leftover:
        mynbundle += 1
        myfirstbundle = rank * mynbundle
    else:
        myfirstbundle = ((mynbundle + 1) * leftover) + \
            (mynbundle * (rank - leftover))

    if rank == 0:
        # Print parameters
        log.info("specex: using {} processes".format(nproc))
        log.info("specex: input image = {}".format(imgfile))
        log.info("specex: input PSF = {}".format(inpsffile))
        log.info("specex: output = {}".format(outfile))
        log.info("specex: bundlesize = {}".format(bundlesize))
        log.info("specex: specmin = {}".format(specmin))
        log.info("specex: specmax = {}".format(specmax))

    # get the root output file

    outpat = re.compile(r'(.*)\.fits')
    outmat = outpat.match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)

    outdir = os.path.dirname(outroot)
    # An output file given without a directory has outdir == '', which
    # os.makedirs rejects; the current directory needs no creation.
    if rank == 0 and outdir:
        os.makedirs(outdir, exist_ok=True)

    failcount = 0

    # myfirstbundle/mynbundle are positions in `bundles`; slice the list so
    # that b is the real bundle number even when specmin is not in bundle 0.
    for b in bundles[myfirstbundle:myfirstbundle + mynbundle]:
        outbundle = "{}_{:02d}".format(outroot, b)
        outbundlefits = "{}.fits".format(outbundle)

        com = ['desi_psf_fit']
        com.extend(['-a', imgfile])
        com.extend(['--in-psf', inpsffile])
        com.extend(['--out-psf', outbundlefits])
        com.extend(['--first-bundle', "{}".format(b)])
        com.extend(['--last-bundle', "{}".format(b)])
        com.extend(['--first-fiber', "{}".format(bspecmin[b])])
        com.extend(['--last-fiber', "{}".format(bspecmin[b] + bnspec[b] - 1)])
        if args.debug:
            com.extend(['--debug'])

        com.extend(optarray)

        log.debug("proc {} calling {}".format(rank, " ".join(com)))

        # Build a C-style (argc, argv) pair; arg_buffers must stay referenced
        # while the library call runs because argv points into them.
        argc = len(com)
        arg_buffers = [ct.create_string_buffer(arg.encode('ascii'))
                       for arg in com]
        addrlist = [ct.cast(x, ct.POINTER(ct.c_char))
                    for x in map(ct.addressof, arg_buffers)]
        arg_pointers = (ct.POINTER(ct.c_char) * argc)(*addrlist)

        retval = libspecex.cspecex_desi_psf_fit(argc, arg_pointers)

        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                      "value {} running {}".format(rank, retval, comstr))
            failcount += 1

    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")

    merge_error = None
    if rank == 0:
        outfits = "{}.fits".format(outroot)

        inputs = ["{}_{:02d}.fits".format(outroot, x) for x in bundles]

        try:
            merge_psf(inputs, outfits)
        except Exception as err:
            # Hold the error until after the broadcast below; raising here
            # would leave the other ranks waiting in comm.bcast forever.
            log.error("merging of per-bundle files failed: {}".format(err))
            merge_error = err
            failcount = 1
        else:
            # only remove the per-bundle files if the merge was good
            for f in inputs:
                if os.path.isfile(f):
                    os.remove(f)

    if comm is not None:
        failcount = comm.bcast(failcount, root=0)

    if merge_error is not None:
        # rank 0 (or the serial case): surface the original exception
        raise merge_error

    if failcount > 0:
        # all other processes throw
        raise RuntimeError("merging of per-bundle files failed")
def main(args=None, comm=None):
    """Process one DESI exposure through the steps selected by its obstype.

    Depending on ``args.obstype`` and the ``no*`` flags this runs, in order:
    fibermap assembly, preprocessing, input-PSF lookup, trace shifts, specex
    PSF fitting (ARC/TESTARC), extraction, fiberflat computation
    (FLAT/TESTFLAT), fiberflat lookup/application, sky-fiber selection, sky
    subtraction, standard-star fitting, flux calibration and cframe creation.
    Most steps build a command line and hand it to ``runcmd``.  Cameras (or
    spectrographs) are distributed round-robin over MPI ranks; PSF fitting and
    extraction instead split ``comm`` into sub-communicators of 20 ranks.
    Timing statistics are gathered on rank 0 and optionally merged into
    ``args.timingfile``.

    Args:
        args: parsed options; ``parse()`` is called when None.
        comm: optional MPI communicator.  When None, ``assign_mpi`` decides
            from ``args.mpi`` / ``args.batch`` whether to use one.

    Raises:
        RuntimeError: if ``args.obstype`` is not a known obstype.
        IOError: if ``args.calibnight`` is set but the corresponding psfnight
            or fiberflatnight file does not exist (raised on rank 0).
        SystemExit: after creating/submitting the batch script when
            ``args.batch`` is set, or with code 13 when no fibermap could be
            produced for a SCIENCE exposure.
    """
    if args is None:
        args = parse()
    # elif isinstance(args, (list, tuple)):
    #     args = parse(args)

    log = get_logger()

    start_mpi_connect = time.time()
    if comm is not None:
        #- Use the provided comm to determine rank and size
        rank = comm.rank
        size = comm.size
    else:
        #- Check MPI flags and determine the comm, rank, and size given the arguments
        comm, rank, size = assign_mpi(do_mpi=args.mpi, do_batch=args.batch, log=log)
    stop_mpi_connect = time.time()

    #- Start timer; only print log messages from rank 0 (others are silent)
    timer = desiutil.timer.Timer(silent=(rank > 0))

    #- Fill in timing information for steps before we had the timer created
    if args.starttime is not None:
        timer.start('startup', starttime=args.starttime)
        timer.stop('startup', stoptime=start_imports)

    timer.start('imports', starttime=start_imports)
    timer.stop('imports', stoptime=stop_imports)

    timer.start('mpi_connect', starttime=start_mpi_connect)
    timer.stop('mpi_connect', stoptime=stop_mpi_connect)

    #- Freeze IERS after parsing args so that it doesn't bother if only --help
    timer.start('freeze_iers')
    desiutil.iers.freeze_iers()
    timer.stop('freeze_iers')

    #- Preflight checks
    timer.start('preflight')
    if rank > 0:
        #- Let rank 0 fetch these, and then broadcast
        args, hdr, camhdr = None, None, None
    else:
        args, hdr, camhdr = update_args_with_headers(args)

    ## Make sure badamps is formatted properly.
    ## Validate on rank 0 whether or not MPI is in use, so that serial runs
    ## do not pass an unchecked --badamps value on to assemble_fibermap.
    if rank == 0 and args.badamps is not None:
        args.badamps = validate_badamps(args.badamps)

    if comm is not None:
        args = comm.bcast(args, root=0)
        hdr = comm.bcast(hdr, root=0)
        camhdr = comm.bcast(camhdr, root=0)

    known_obstype = [
        'SCIENCE', 'ARC', 'FLAT', 'ZERO', 'DARK', 'TESTARC', 'TESTFLAT',
        'PIXFLAT', 'SKY', 'TWILIGHT', 'OTHER'
    ]
    if args.obstype not in known_obstype:
        raise RuntimeError('obstype {} not in {}'.format(
            args.obstype, known_obstype))

    timer.stop('preflight')

    #-------------------------------------------------------------------------
    #- Create and submit a batch job if requested

    if args.batch:
        #exp_str = '{:08d}'.format(args.expid)
        jobdesc = args.obstype.lower()
        if args.obstype == 'SCIENCE':
            # if not doing pre-stdstar fitting or stdstar fitting and if there is
            # no flag stopping flux calibration, set job to poststdstar
            if args.noprestdstarfit and args.nostdstarfit and (
                    not args.nofluxcalib):
                jobdesc = 'poststdstar'
            # elif told not to do std or post stdstar but the flag for prestdstar isn't set,
            # then perform prestdstar
            elif (not args.noprestdstarfit
                  ) and args.nostdstarfit and args.nofluxcalib:
                jobdesc = 'prestdstar'
            #elif (not args.noprestdstarfit) and (not args.nostdstarfit) and (not args.nofluxcalib):
            #    jobdesc = 'science'
        scriptfile = create_desi_proc_batch_script(
            night=args.night, exp=args.expid, cameras=args.cameras,
            jobdesc=jobdesc, queue=args.queue, runtime=args.runtime,
            batch_opts=args.batch_opts, timingfile=args.timingfile,
            system_name=args.system_name)
        err = 0
        if not args.nosubmit:
            err = subprocess.call(['sbatch', scriptfile])
        sys.exit(err)

    #-------------------------------------------------------------------------
    #- Proceeding with running

    #- What are we going to do?
    if rank == 0:
        log.info('----------')
        log.info('Input {}'.format(args.input))
        log.info('Night {} expid {}'.format(args.night, args.expid))
        log.info('Obstype {}'.format(args.obstype))
        log.info('Cameras {}'.format(args.cameras))
        log.info('Output root {}'.format(desispec.io.specprod_root()))
        log.info('----------')

    #- Create output directories if needed
    if rank == 0:
        preprocdir = os.path.dirname(
            findfile('preproc', args.night, args.expid, 'b0'))
        expdir = os.path.dirname(
            findfile('frame', args.night, args.expid, 'b0'))
        os.makedirs(preprocdir, exist_ok=True)
        os.makedirs(expdir, exist_ok=True)

    #- Wait for rank 0 to make directories before proceeding
    if comm is not None:
        comm.barrier()

    #-------------------------------------------------------------------------
    #- Preproc
    #- All obstypes get preprocessed

    timer.start('fibermap')

    #- Assemble fibermap for science exposures
    fibermap = None
    fibermap_ok = None
    if rank == 0 and args.obstype == 'SCIENCE':
        fibermap = findfile('fibermap', args.night, args.expid)
        if not os.path.exists(fibermap):
            #- Not at the standard location: create it next to the preproc files
            tmp = findfile('preproc', args.night, args.expid, 'b0')
            preprocdir = os.path.dirname(tmp)
            fibermap = os.path.join(preprocdir, os.path.basename(fibermap))

            log.info('Creating fibermap {}'.format(fibermap))
            cmd = 'assemble_fibermap -n {} -e {} -o {}'.format(
                args.night, args.expid, fibermap)
            if args.badamps is not None:
                cmd += ' --badamps={}'.format(args.badamps)
            runcmd(cmd, inputs=[], outputs=[fibermap])

        fibermap_ok = os.path.exists(fibermap)

        #- Some commissioning files didn't have coords* files that caused assemble_fibermap to fail
        #- these are well known failures with no other solution, so for those, just force creation
        #- of a fibermap with null coordinate information
        if not fibermap_ok and int(args.night) < 20200310:
            log.info(
                "Since night is before 20200310, trying to force fibermap creation without coords file"
            )
            cmd += ' --force'
            runcmd(cmd, inputs=[], outputs=[fibermap])
            fibermap_ok = os.path.exists(fibermap)

    #- If assemble_fibermap failed and obstype is SCIENCE, exit now
    if comm is not None:
        fibermap_ok = comm.bcast(fibermap_ok, root=0)

    if args.obstype == 'SCIENCE' and not fibermap_ok:
        sys.stdout.flush()
        if rank == 0:
            log.critical(
                'assemble_fibermap failed for science exposure; exiting now')

        sys.exit(13)

    #- Wait for rank 0 to make fibermap if needed
    if comm is not None:
        fibermap = comm.bcast(fibermap, root=0)

    timer.stop('fibermap')

    if not (args.obstype in ['SCIENCE'] and args.noprestdstarfit):
        timer.start('preproc')
        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            outfile = findfile('preproc', args.night, args.expid, camera)
            outdir = os.path.dirname(outfile)
            cmd = "desi_preproc -i {} -o {} --outdir {} --cameras {}".format(
                args.input, outfile, outdir, camera)
            if args.scattered_light:
                cmd += " --scattered-light"
            if fibermap is not None:
                cmd += " --fibermap {}".format(fibermap)
            if args.obstype not in ['ARC']:  # never model variance for arcs
                if not args.no_model_pixel_variance:
                    cmd += " --model-variance"

            runcmd(cmd, inputs=[args.input], outputs=[outfile])

        timer.stop('preproc')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Get input PSFs
    timer.start('findpsf')
    input_psf = dict()
    if rank == 0:
        for camera in args.cameras:
            if args.psf is not None:
                input_psf[camera] = args.psf
            elif args.calibnight is not None:
                # look for a psfnight psf for this calib night
                psfnightfile = findfile('psfnight', args.calibnight,
                                        args.expid, camera)
                if not os.path.isfile(psfnightfile):
                    log.error("no {}".format(psfnightfile))
                    raise IOError("no {}".format(psfnightfile))
                input_psf[camera] = psfnightfile
            else:
                # look for a psfnight psf
                psfnightfile = findfile('psfnight', args.night, args.expid,
                                        camera)
                if os.path.isfile(psfnightfile):
                    input_psf[camera] = psfnightfile
                elif args.most_recent_calib:
                    nightfile = find_most_recent(args.night,
                                                 file_type='psfnight')
                    if nightfile is None:
                        input_psf[camera] = findcalibfile(
                            [hdr, camhdr[camera]], 'PSF')
                    else:
                        input_psf[camera] = nightfile
                else:
                    input_psf[camera] = findcalibfile([hdr, camhdr[camera]],
                                                      'PSF')
            log.info("Will use input PSF : {}".format(input_psf[camera]))

    if comm is not None:
        input_psf = comm.bcast(input_psf, root=0)

    timer.stop('findpsf')

    #-------------------------------------------------------------------------
    #- Traceshift

    if ( args.obstype in ['FLAT', 'TESTFLAT', 'SKY', 'TWILIGHT'] ) or \
       ( args.obstype in ['SCIENCE'] and (not args.noprestdstarfit) ):

        timer.start('traceshift')

        if rank == 0 and args.traceshift:
            log.info('Starting traceshift at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            preprocfile = findfile('preproc', args.night, args.expid, camera)
            inpsf = input_psf[camera]
            outpsf = findfile('psf', args.night, args.expid, camera)
            if not os.path.isfile(outpsf):
                if args.traceshift:
                    cmd = "desi_compute_trace_shifts"
                    cmd += " -i {}".format(preprocfile)
                    cmd += " --psf {}".format(inpsf)
                    cmd += " --outpsf {}".format(outpsf)
                    cmd += " --degxx 2 --degxy 0"
                    if args.obstype in ['FLAT', 'TESTFLAT', 'TWILIGHT']:
                        cmd += " --continuum"
                    else:
                        cmd += " --degyx 2 --degyy 0"
                    if args.obstype in ['SCIENCE', 'SKY']:
                        cmd += ' --sky'
                else:
                    #- No trace shift requested: the exposure PSF is just a
                    #- link to the input PSF
                    cmd = "ln -s {} {}".format(inpsf, outpsf)
                runcmd(cmd, inputs=[preprocfile, inpsf], outputs=[outpsf])
            else:
                log.info("PSF {} exists".format(outpsf))

        timer.stop('traceshift')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- PSF
    #- MPI parallelize this step

    if args.obstype in ['ARC', 'TESTARC']:

        timer.start('arc_traceshift')

        if rank == 0:
            log.info('Starting traceshift before specex PSF fit at {}'.format(
                time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            preprocfile = findfile('preproc', args.night, args.expid, camera)
            inpsf = input_psf[camera]
            outpsf = findfile('psf', args.night, args.expid, camera)
            outpsf = replace_prefix(outpsf, "psf", "shifted-input-psf")
            if not os.path.isfile(outpsf):
                cmd = "desi_compute_trace_shifts"
                cmd += " -i {}".format(preprocfile)
                cmd += " --psf {}".format(inpsf)
                cmd += " --outpsf {}".format(outpsf)
                cmd += " --degxx 0 --degxy 0 --degyx 0 --degyy 0"
                cmd += ' --arc-lamps'
                runcmd(cmd, inputs=[preprocfile, inpsf], outputs=[outpsf])
            else:
                log.info("PSF {} exists".format(outpsf))

        timer.stop('arc_traceshift')
        if comm is not None:
            comm.barrier()

        timer.start('psf')

        if rank == 0:
            log.info('Starting specex PSF fitting at {}'.format(
                time.asctime()))

        #- Rank 0 builds the per-camera commands; they are broadcast below
        if rank > 0:
            cmds = inputs = outputs = None
        else:
            cmds = dict()
            inputs = dict()
            outputs = dict()
            for camera in args.cameras:
                preprocfile = findfile('preproc', args.night, args.expid,
                                       camera)
                tmpname = findfile('psf', args.night, args.expid, camera)
                inpsf = replace_prefix(tmpname, "psf", "shifted-input-psf")
                outpsf = replace_prefix(tmpname, "psf", "fit-psf")

                log.info("now run specex psf fit")

                cmd = 'desi_compute_psf'
                cmd += ' --input-image {}'.format(preprocfile)
                cmd += ' --input-psf {}'.format(inpsf)
                cmd += ' --output-psf {}'.format(outpsf)

                # look for fiber blacklist
                cfinder = CalibFinder([hdr, camhdr[camera]])
                blacklistkey = "FIBERBLACKLIST"
                if not cfinder.haskey(blacklistkey) and cfinder.haskey(
                        "BROKENFIBERS"):
                    log.warning(
                        "BROKENFIBERS yaml keyword deprecated, please use FIBERBLACKLIST"
                    )
                    blacklistkey = "BROKENFIBERS"

                if cfinder.haskey(blacklistkey):
                    blacklist = cfinder.value(blacklistkey)
                    cmd += ' --broken-fibers {}'.format(blacklist)
                    if rank == 0:
                        log.warning('broken fibers: {}'.format(blacklist))

                #- Only queue cameras whose fit output does not exist yet
                if not os.path.exists(outpsf):
                    cmds[camera] = cmd
                    inputs[camera] = [preprocfile, inpsf]
                    outputs[camera] = [
                        outpsf,
                    ]

        if comm is not None:
            cmds = comm.bcast(cmds, root=0)
            inputs = comm.bcast(inputs, root=0)
            outputs = comm.bcast(outputs, root=0)
            #- split communicator by 20 (number of bundles)
            group_size = 20
            if (rank == 0) and (size % group_size != 0):
                log.warning(
                    'MPI size={} should be evenly divisible by {}'.format(
                        size, group_size))

            group = rank // group_size
            num_groups = (size + group_size - 1) // group_size
            comm_group = comm.Split(color=group)

            if rank == 0:
                log.info(
                    f'Fitting PSFs with {num_groups} sub-communicators of size {group_size}'
                )

            for i in range(group, len(args.cameras), num_groups):
                camera = args.cameras[i]
                if camera in cmds:
                    #- Drop the program name and parse the remaining options
                    cmdargs = cmds[camera].split()[1:]
                    cmdargs = desispec.scripts.specex.parse(cmdargs)
                    if comm_group.rank == 0:
                        print('RUNNING: {}'.format(cmds[camera]))
                        t0 = time.time()
                        timestamp = time.asctime()
                        log.info(
                            f'MPI group {group} ranks {rank}-{rank+group_size-1} fitting PSF for {camera} at {timestamp}'
                        )
                    try:
                        desispec.scripts.specex.main(cmdargs, comm=comm_group)
                    except Exception as e:
                        #- A failed camera is logged but does not stop the
                        #- other cameras from being fit
                        if comm_group.rank == 0:
                            log.error(
                                f'FAILED: MPI group {group} ranks {rank}-{rank+group_size-1} camera {camera}'
                            )
                            log.error('FAILED: {}'.format(cmds[camera]))
                            log.error(e)

                    if comm_group.rank == 0:
                        specex_time = time.time() - t0
                        log.info(
                            f'specex fit for {camera} took {specex_time:.1f} seconds'
                        )

            comm.barrier()
        else:
            log.warning(
                'fitting PSFs without MPI parallelism; this will be SLOW')
            for camera in args.cameras:
                if camera in cmds:
                    runcmd(cmds[camera],
                           inputs=inputs[camera],
                           outputs=outputs[camera])

        if comm is not None:
            comm.barrier()

        # loop on all cameras and interpolate bad fibers
        for camera in args.cameras[rank::size]:
            t0 = time.time()
            log.info(f'Rank {rank} interpolating {camera} PSF over bad fibers')
            # look for fiber blacklist
            cfinder = CalibFinder([hdr, camhdr[camera]])
            blacklistkey = "FIBERBLACKLIST"
            if not cfinder.haskey(blacklistkey) and cfinder.haskey(
                    "BROKENFIBERS"):
                log.warning(
                    "BROKENFIBERS yaml keyword deprecated, please use FIBERBLACKLIST"
                )
                blacklistkey = "BROKENFIBERS"

            if cfinder.haskey(blacklistkey):
                fiberblacklist = cfinder.value(blacklistkey)
                tmpname = findfile('psf', args.night, args.expid, camera)
                inpsf = replace_prefix(tmpname, "psf", "fit-psf")
                outpsf = replace_prefix(tmpname, "psf",
                                        "fit-psf-fixed-blacklisted")
                if os.path.isfile(inpsf) and not os.path.isfile(outpsf):
                    cmd = 'desi_interpolate_fiber_psf'
                    cmd += ' --infile {}'.format(inpsf)
                    cmd += ' --outfile {}'.format(outpsf)
                    cmd += ' --fibers {}'.format(fiberblacklist)
                    log.info(
                        'For camera {} interpolating PSF for broken fibers: {}'
                        .format(camera, fiberblacklist))
                    runcmd(cmd, inputs=[inpsf], outputs=[outpsf])
                    if os.path.isfile(outpsf):
                        #- Keep the unfixed fit under a new name, then put a
                        #- copy of the fixed PSF at the original fit-psf path
                        os.rename(
                            inpsf,
                            inpsf.replace("fit-psf",
                                          "fit-psf-before-blacklisted-fix"))
                        #- Argument list instead of a shell string so that
                        #- the paths are never interpreted by a shell
                        subprocess.call(['cp', outpsf, inpsf])

            dt = time.time() - t0
            log.info(
                f'Rank {rank} {camera} PSF interpolation took {dt:.1f} sec')

        timer.stop('psf')

    #-------------------------------------------------------------------------
    #- Merge PSF of night if applicable

    #- Currently disabled
    #if args.obstype in ['ARC']:
    if False:
        if rank == 0:
            for camera in args.cameras:
                psfnightfile = findfile('psfnight', args.night, args.expid,
                                        camera)
                if not os.path.isfile(
                        psfnightfile
                ):  # we still don't have a psf night, see if we can compute it ...
                    psfs = glob.glob(
                        findfile('psf', args.night, args.expid,
                                 camera).replace("psf", "fit-psf").replace(
                                     str(args.expid), "*"))
                    log.info(
                        "Number of PSF for night={} camera={} = {}".format(
                            args.night, camera, len(psfs)))
                    if len(psfs) > 4:  # lets do it!
                        log.info("Computing psfnight ...")
                        dirname = os.path.dirname(psfnightfile)
                        if not os.path.isdir(dirname):
                            os.makedirs(dirname)
                        desispec.scripts.specex.mean_psf(psfs, psfnightfile)
                if os.path.isfile(psfnightfile):  # now use this one
                    input_psf[camera] = psfnightfile

    #-------------------------------------------------------------------------
    #- Extract
    #- This is MPI parallel so handle a bit differently

    # maybe add ARC and TESTARC too
    if ( args.obstype in ['FLAT', 'TESTFLAT', 'SKY', 'TWILIGHT'] ) or \
       ( args.obstype in ['SCIENCE'] and (not args.noprestdstarfit) ):

        timer.start('extract')
        if rank == 0:
            log.info('Starting extractions at {}'.format(time.asctime()))

        #- Rank 0 builds the per-camera commands; they are broadcast below
        if rank > 0:
            cmds = inputs = outputs = None
        else:
            cmds = dict()
            inputs = dict()
            outputs = dict()
            for camera in args.cameras:
                cmd = 'desi_extract_spectra'

                #- Based on data from SM1-SM8, looking at central and edge fibers
                #- with in mind overlapping arc lamps lines
                if camera.startswith('b'):
                    cmd += ' -w 3600.0,5800.0,0.8'
                elif camera.startswith('r'):
                    cmd += ' -w 5760.0,7620.0,0.8'
                elif camera.startswith('z'):
                    cmd += ' -w 7520.0,9824.0,0.8'

                preprocfile = findfile('preproc', args.night, args.expid,
                                       camera)
                psffile = findfile('psf', args.night, args.expid, camera)
                framefile = findfile('frame', args.night, args.expid, camera)
                cmd += ' -i {}'.format(preprocfile)
                cmd += ' -p {}'.format(psffile)
                cmd += ' -o {}'.format(framefile)
                cmd += ' --psferr 0.1'

                if args.obstype == 'SCIENCE' or args.obstype == 'SKY':
                    if rank == 0:
                        log.info('Include barycentric correction')
                    cmd += ' --barycentric-correction'

                #- Only queue cameras whose frame does not exist yet
                if not os.path.exists(framefile):
                    cmds[camera] = cmd
                    inputs[camera] = [preprocfile, psffile]
                    outputs[camera] = [
                        framefile,
                    ]

        #- TODO: refactor/combine this with PSF comm splitting logic
        if comm is not None:
            cmds = comm.bcast(cmds, root=0)
            inputs = comm.bcast(inputs, root=0)
            outputs = comm.bcast(outputs, root=0)

            #- split communicator by 20 (number of bundles)
            extract_size = 20
            if (rank == 0) and (size % extract_size != 0):
                log.warning(
                    'MPI size={} should be evenly divisible by {}'.format(
                        size, extract_size))

            extract_group = rank // extract_size
            num_extract_groups = (size + extract_size - 1) // extract_size
            comm_extract = comm.Split(color=extract_group)

            for i in range(extract_group, len(args.cameras),
                           num_extract_groups):
                camera = args.cameras[i]
                if camera in cmds:
                    #- Drop the program name and parse the remaining options
                    cmdargs = cmds[camera].split()[1:]
                    extract_args = desispec.scripts.extract.parse(cmdargs)
                    if comm_extract.rank == 0:
                        print('RUNNING: {}'.format(cmds[camera]))

                    desispec.scripts.extract.main_mpi(extract_args,
                                                      comm=comm_extract)

            comm.barrier()

        else:
            log.warning(
                'running extractions without MPI parallelism; this will be SLOW'
            )
            for camera in args.cameras:
                if camera in cmds:
                    runcmd(cmds[camera],
                           inputs=inputs[camera],
                           outputs=outputs[camera])

        timer.stop('extract')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Fiberflat

    if args.obstype in ['FLAT', 'TESTFLAT']:
        timer.start('fiberflat')
        if rank == 0:
            log.info('Starting fiberflats at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            framefile = findfile('frame', args.night, args.expid, camera)
            fiberflatfile = findfile('fiberflat', args.night, args.expid,
                                     camera)
            cmd = "desi_compute_fiberflat"
            cmd += " -i {}".format(framefile)
            cmd += " -o {}".format(fiberflatfile)
            runcmd(cmd, inputs=[
                framefile,
            ], outputs=[
                fiberflatfile,
            ])

        timer.stop('fiberflat')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Average and auto-calib fiberflats of night if applicable

    #- Currently disabled
    #if args.obstype in ['FLAT']:
    if False:
        if rank == 0:
            fiberflatnightfile = findfile('fiberflatnight', args.night,
                                          args.expid, args.cameras[0])
            fiberflatdirname = os.path.dirname(fiberflatnightfile)
            if not os.path.isfile(fiberflatnightfile) and len(
                    args.cameras
            ) >= 6:  # we still don't have them, see if we can compute them, but need at least 2 spectros ...
                flats = glob.glob(
                    findfile('fiberflat', args.night, args.expid,
                             "b0").replace(str(args.expid),
                                           "*").replace("b0", "*"))
                log.info("Number of fiberflat for night {} = {}".format(
                    args.night, len(flats)))
                if len(flats) >= 3 * 4 * len(
                        args.cameras
                ):  # lets do it! (3 exposures x 4 lamps x N cameras)
                    log.info(
                        "Computing fiberflatnight per lamp and camera ...")
                    tmpdir = os.path.join(fiberflatdirname, "tmp")
                    if not os.path.isdir(tmpdir):
                        os.makedirs(tmpdir)

                    log.info(
                        "First average measurements per camera and per lamp")
                    average_flats = dict()
                    for camera in args.cameras:
                        # list of flats for this camera
                        flats_for_this_camera = []
                        for flat in flats:
                            if flat.find(camera) >= 0:
                                flats_for_this_camera.append(flat)
                        #log.info("For camera {} , flats = {}".format(camera,flats_for_this_camera))
                        #sys.exit(12)

                        # average per lamp (and camera)
                        average_flats[camera] = list()
                        for lampbox in range(4):
                            ofile = os.path.join(
                                tmpdir,
                                "fiberflatnight-camera-{}-lamp-{}.fits".format(
                                    camera, lampbox))
                            if not os.path.isfile(ofile):
                                log.info(
                                    "Average flat for camera {} and lamp box #{}"
                                    .format(camera, lampbox))
                                pg = "CALIB DESI-CALIB-0{} LEDs only".format(
                                    lampbox)
                                cmd = "desi_average_fiberflat --program '{}' --outfile {} -i ".format(
                                    pg, ofile)
                                for flat in flats_for_this_camera:
                                    cmd += " {} ".format(flat)
                                runcmd(cmd,
                                       inputs=flats_for_this_camera,
                                       outputs=[
                                           ofile,
                                       ])
                                if os.path.isfile(ofile):
                                    average_flats[camera].append(ofile)
                            else:
                                log.info("Will use existing {}".format(ofile))
                                average_flats[camera].append(ofile)

                    log.info(
                        "Auto-calibration across lamps and spectro per camera arm (b,r,z)"
                    )
                    for camera_arm in ["b", "r", "z"]:
                        cameras_for_this_arm = []
                        flats_for_this_arm = []
                        for camera in args.cameras:
                            if camera[0].lower() == camera_arm:
                                cameras_for_this_arm.append(camera)
                                if camera in average_flats:
                                    for flat in average_flats[camera]:
                                        flats_for_this_arm.append(flat)
                        cmd = "desi_autocalib_fiberflat --night {} --arm {} -i ".format(
                            args.night, camera_arm)
                        for flat in flats_for_this_arm:
                            cmd += " {} ".format(flat)
                        runcmd(cmd, inputs=flats_for_this_arm, outputs=[])
                    log.info("Done with fiber flats per night")

        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Get input fiberflat
    if args.obstype in ['SCIENCE', 'SKY'] and (not args.nofiberflat):
        timer.start('find_fiberflat')
        input_fiberflat = dict()
        if rank == 0:
            for camera in args.cameras:
                if args.fiberflat is not None:
                    input_fiberflat[camera] = args.fiberflat
                elif args.calibnight is not None:
                    # look for a fiberflatnight for this calib night
                    fiberflatnightfile = findfile('fiberflatnight',
                                                  args.calibnight, args.expid,
                                                  camera)
                    if not os.path.isfile(fiberflatnightfile):
                        log.error("no {}".format(fiberflatnightfile))
                        raise IOError("no {}".format(fiberflatnightfile))
                    input_fiberflat[camera] = fiberflatnightfile
                else:
                    # look for a fiberflatnight fiberflat
                    fiberflatnightfile = findfile('fiberflatnight',
                                                  args.night, args.expid,
                                                  camera)
                    if os.path.isfile(fiberflatnightfile):
                        input_fiberflat[camera] = fiberflatnightfile
                    elif args.most_recent_calib:
                        nightfile = find_most_recent(
                            args.night, file_type='fiberflatnight')
                        if nightfile is None:
                            input_fiberflat[camera] = findcalibfile(
                                [hdr, camhdr[camera]], 'FIBERFLAT')
                        else:
                            input_fiberflat[camera] = nightfile
                    else:
                        input_fiberflat[camera] = findcalibfile(
                            [hdr, camhdr[camera]], 'FIBERFLAT')
                log.info("Will use input FIBERFLAT: {}".format(
                    input_fiberflat[camera]))

        if comm is not None:
            input_fiberflat = comm.bcast(input_fiberflat, root=0)

        timer.stop('find_fiberflat')

    #-------------------------------------------------------------------------
    #- Apply fiberflat and write fframe file

    if args.obstype in ['SCIENCE', 'SKY'] and args.fframe and \
       ( not args.nofiberflat ) and (not args.noprestdstarfit):
        timer.start('apply_fiberflat')
        if rank == 0:
            log.info('Applying fiberflat at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            fframefile = findfile('fframe', args.night, args.expid, camera)
            if not os.path.exists(fframefile):
                framefile = findfile('frame', args.night, args.expid, camera)
                fr = desispec.io.read_frame(framefile)
                flatfilename = input_fiberflat[camera]
                if flatfilename is not None:
                    ff = desispec.io.read_fiberflat(flatfilename)
                    fr.meta['FIBERFLT'] = desispec.io.shorten_filename(
                        flatfilename)
                    apply_fiberflat(fr, ff)

                    fframefile = findfile('fframe', args.night, args.expid,
                                          camera)
                    desispec.io.write_frame(fframefile, fr)
                else:
                    log.warning(
                        "Missing fiberflat for camera {}".format(camera))

        timer.stop('apply_fiberflat')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Select random sky fibers (inplace update of frame file)
    #- TODO: move this to a function somewhere
    #- TODO: this assigns different sky fibers to each frame of same spectrograph

    if (args.obstype in [
            'SKY', 'SCIENCE'
    ]) and (not args.noskysub) and (not args.noprestdstarfit):
        timer.start('picksky')
        if rank == 0:
            log.info('Picking sky fibers at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            framefile = findfile('frame', args.night, args.expid, camera)
            orig_frame = desispec.io.read_frame(framefile)

            #- Make a copy so that we can apply fiberflat
            fr = deepcopy(orig_frame)

            if np.any(fr.fibermap['OBJTYPE'] == 'SKY'):
                log.info('{} sky fibers already set; skipping'.format(
                    os.path.basename(framefile)))
                continue

            #- Apply fiberflat then select random fibers below a flux cut
            flatfilename = input_fiberflat[camera]
            if flatfilename is None:
                log.error("No fiberflat for {}".format(camera))
                continue
            ff = desispec.io.read_fiberflat(flatfilename)
            apply_fiberflat(fr, ff)
            sumflux = np.sum(fr.flux, axis=1)
            fluxcut = np.percentile(sumflux, 30)
            iisky = np.where(sumflux < fluxcut)[0]
            #- 100 fibers drawn without replacement from the faintest 30%
            iisky = np.random.choice(iisky, size=100, replace=False)

            #- Update fibermap or original frame and write out
            orig_frame.fibermap['OBJTYPE'][iisky] = 'SKY'
            orig_frame.fibermap['DESI_TARGET'][iisky] |= desi_mask.SKY

            desispec.io.write_frame(framefile, orig_frame)

        timer.stop('picksky')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Sky subtraction
    if args.obstype in [
            'SCIENCE', 'SKY'
    ] and (not args.noskysub) and (not args.noprestdstarfit):
        timer.start('skysub')
        if rank == 0:
            log.info('Starting sky subtraction at {}'.format(time.asctime()))

        for i in range(rank, len(args.cameras), size):
            camera = args.cameras[i]
            framefile = findfile('frame', args.night, args.expid, camera)
            #- NOTE(review): this rebinds `hdr` (the header obtained during
            #- preflight) and the value is not used again in this function;
            #- kept so that an unreadable frame still fails at this point
            hdr = fitsio.read_header(framefile, 'FLUX')
            fiberflatfile = input_fiberflat[camera]
            if fiberflatfile is None:
                log.error("No fiberflat for {}".format(camera))
                continue
            skyfile = findfile('sky', args.night, args.expid, camera)

            cmd = "desi_compute_sky"
            cmd += " -i {}".format(framefile)
            cmd += " --fiberflat {}".format(fiberflatfile)
            #- NOTE(review): '--o' presumably relies on option-prefix
            #- matching in desi_compute_sky -- confirm
            cmd += " --o {}".format(skyfile)
            if args.no_extra_variance:
                cmd += " --no-extra-variance"
            if not args.no_sky_wavelength_adjustment:
                cmd += " --adjust-wavelength"
            if not args.no_sky_lsf_adjustment:
                cmd += " --adjust-lsf"

            runcmd(cmd, inputs=[framefile, fiberflatfile], outputs=[
                skyfile,
            ])

            #- sframe = flatfielded sky-subtracted but not flux calibrated frame
            #- Note: this re-reads and re-does steps previously done for picking
            #- sky fibers; desi_proc is about human efficiency,
            #- not I/O or CPU efficiency...
            sframefile = desispec.io.findfile('sframe', args.night,
                                              args.expid, camera)
            if not os.path.exists(sframefile):
                frame = desispec.io.read_frame(framefile)
                fiberflat = desispec.io.read_fiberflat(fiberflatfile)
                sky = desispec.io.read_sky(skyfile)
                apply_fiberflat(frame, fiberflat)
                subtract_sky(frame, sky, apply_throughput_correction=True)
                frame.meta['IN_SKY'] = shorten_filename(skyfile)
                frame.meta['FIBERFLT'] = shorten_filename(fiberflatfile)
                desispec.io.write_frame(sframefile, frame)

        timer.stop('skysub')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Standard Star Fitting

    if args.obstype in ['SCIENCE',] and \
            (not args.noskysub ) and \
            (not args.nostdstarfit) :

        timer.start('stdstarfit')
        if rank == 0:
            log.info('Starting flux calibration at {}'.format(time.asctime()))

        #- Group inputs by spectrograph
        framefiles = dict()
        skyfiles = dict()
        fiberflatfiles = dict()
        night, expid = args.night, args.expid  #- shorter
        for camera in args.cameras:
            #- second character of the camera name is the spectrograph number
            sp = int(camera[1])
            if sp not in framefiles:
                framefiles[sp] = list()
                skyfiles[sp] = list()
                fiberflatfiles[sp] = list()

            framefiles[sp].append(findfile('frame', night, expid, camera))
            skyfiles[sp].append(findfile('sky', night, expid, camera))
            fiberflatfiles[sp].append(input_fiberflat[camera])

        #- Hardcoded stdstar model version
        starmodels = os.path.join(os.getenv('DESI_BASIS_TEMPLATES'),
                                  'stdstar_templates_v2.2.fits')

        #- Fit stdstars per spectrograph (not per-camera)
        spectro_nums = sorted(framefiles.keys())
        ## for sp in spectro_nums[rank::size]:
        for i in range(rank, len(spectro_nums), size):
            sp = spectro_nums[i]

            stdfile = findfile('stdstars', night, expid, spectrograph=sp)
            cmd = "desi_fit_stdstars"
            cmd += " --frames {}".format(' '.join(framefiles[sp]))
            cmd += " --skymodels {}".format(' '.join(skyfiles[sp]))
            cmd += " --fiberflats {}".format(' '.join(fiberflatfiles[sp]))
            cmd += " --starmodels {}".format(starmodels)
            cmd += " --outfile {}".format(stdfile)
            cmd += " --delta-color 0.1"
            if args.maxstdstars is not None:
                cmd += " --maxstdstars {}".format(args.maxstdstars)

            inputs = framefiles[sp] + skyfiles[sp] + fiberflatfiles[sp]
            runcmd(cmd, inputs=inputs, outputs=[stdfile])

        timer.stop('stdstarfit')
        if comm is not None:
            comm.barrier()

    # -------------------------------------------------------------------------
    # - Flux calibration

    if args.obstype in ['SCIENCE'] and \
                (not args.noskysub) and \
                (not args.nofluxcalib):
        timer.start('fluxcalib')

        night, expid = args.night, args.expid  #- shorter
        #- Compute flux calibration vectors per camera
        for camera in args.cameras[rank::size]:
            framefile = findfile('frame', night, expid, camera)
            skyfile = findfile('sky', night, expid, camera)
            spectrograph = int(camera[1])
            stdfile = findfile('stdstars',
                               night,
                               expid,
                               spectrograph=spectrograph)
            calibfile = findfile('fluxcalib', night, expid, camera)

            fiberflatfile = input_fiberflat[camera]

            cmd = "desi_compute_fluxcalibration"
            cmd += " --infile {}".format(framefile)
            cmd += " --sky {}".format(skyfile)
            cmd += " --fiberflat {}".format(fiberflatfile)
            cmd += " --models {}".format(stdfile)
            cmd += " --outfile {}".format(calibfile)
            cmd += " --delta-color-cut 0.1"

            inputs = [framefile, skyfile, fiberflatfile, stdfile]
            runcmd(cmd, inputs=inputs, outputs=[
                calibfile,
            ])

        timer.stop('fluxcalib')
        if comm is not None:
            comm.barrier()

    #-------------------------------------------------------------------------
    #- Applying flux calibration

    if args.obstype in [
            'SCIENCE',
    ] and (not args.noskysub) and (not args.nofluxcalib):

        night, expid = args.night, args.expid  #- shorter

        timer.start('applycalib')
        if rank == 0:
            log.info('Starting cframe file creation at {}'.format(
                time.asctime()))

        for camera in args.cameras[rank::size]:
            framefile = findfile('frame', night, expid, camera)
            skyfile = findfile('sky', night, expid, camera)
            spectrograph = int(camera[1])
            stdfile = findfile('stdstars',
                               night,
                               expid,
                               spectrograph=spectrograph)
            calibfile = findfile('fluxcalib', night, expid, camera)
            cframefile = findfile('cframe', night, expid, camera)

            fiberflatfile = input_fiberflat[camera]

            cmd = "desi_process_exposure"
            cmd += " --infile {}".format(framefile)
            cmd += " --fiberflat {}".format(fiberflatfile)
            cmd += " --sky {}".format(skyfile)
            cmd += " --calib {}".format(calibfile)
            cmd += " --outfile {}".format(cframefile)
            cmd += " --cosmics-nsig 6"
            if args.no_xtalk:
                cmd += " --no-xtalk"

            inputs = [framefile, fiberflatfile, skyfile, calibfile]
            runcmd(cmd, inputs=inputs, outputs=[
                cframefile,
            ])

        if comm is not None:
            comm.barrier()

        timer.stop('applycalib')

    #-------------------------------------------------------------------------
    #- Wrap up

    # if rank == 0:
    #     report = timer.report()
    #     log.info('Rank 0 timing report:\n' + report)

    if comm is not None:
        timers = comm.gather(timer, root=0)
    else:
        timers = [
            timer,
        ]

    if rank == 0:
        stats = desiutil.timer.compute_stats(timers)
        log.info('Timing summary statistics:\n' + json.dumps(stats, indent=2))

        if args.timingfile:
            if os.path.exists(args.timingfile):
                with open(args.timingfile) as fx:
                    previous_stats = json.load(fx)

                #- augment previous_stats with new entries, but don't overwrite old
                for name in stats:
                    if name not in previous_stats:
                        previous_stats[name] = stats[name]

                stats = previous_stats

            #- Write to a temporary file and rename it over the target so a
            #- partially written timing file is never left in place
            tmpfile = args.timingfile + '.tmp'
            with open(tmpfile, 'w') as fx:
                json.dump(stats, fx, indent=2)
            os.rename(tmpfile, args.timingfile)

    if rank == 0:
        log.info('All done at {}'.format(time.asctime()))
def main(args):
    """Fit fiber crosstalk coefficients on sky fibers and write them to a table.

    For each input frame, every usable sky fiber is modelled as a linear
    combination of the kernel-convolved spectra of its neighbours at fiber
    offsets -2, -1, +1 and +2 within the same bundle, plus a constant term.
    The weighted least-squares normal equations are accumulated per bundle
    and per output wavelength bin over all input frames, then solved.

    Args:
        args: parsed command line options. The attributes read here are
            ``infile`` (iterable of frame file names), ``outfile`` (path of
            the output table, overwritten if it exists) and ``plot`` (when
            true, display the fitted coefficients).

    Raises:
        IOError: if the fiberflat file of a frame does not exist.
        ValueError: if ``args.infile`` contains no frame.

    The output table has a ``WAVELENGTH`` column plus, for each bundle and
    fiber offset, ``CROSSTALK-Bxx-F+d``, ``CROSSTALKIVAR-Bxx-F+d`` and
    ``NINPUT-Bxx-F+d`` columns.
    """
    log = get_logger()

    # precompute convolution kernels
    kernels = compute_crosstalk_kernels()

    A = None
    B = None
    out_wave = None

    # offsets of the contaminating fibers relative to the sky fiber
    # (nearest neighbours only would be np.array([-1, 1]))
    dfiber = np.array([-2, -1, 1, 2])
    npar = dfiber.size
    with_cst = True  # to marginalize over residual background (should not change much)
    if with_cst:
        npar += 1

    # one measurement per fiber bundle
    nfiber_per_bundle = 25
    nbundles = 500 // nfiber_per_bundle

    previous_psf_filename = None
    previous_fiberflat_filename = None
    for filename in args.infile:

        # read a frame and find the sky fibers
        frame = read_frame(filename)

        if out_wave is None:
            # coarse output grid of 40 bins, defined by the first frame
            dwave = (frame.wave[-1] - frame.wave[0]) / 40
            out_wave = np.linspace(frame.wave[0] + dwave / 2,
                                   frame.wave[-1] - dwave / 2, 40)

        # find fiberflat
        if "FIBERFLT" in frame.meta.keys():
            flatname = frame.meta["FIBERFLT"]
            if flatname.find("SPCALIB") >= 0:
                flatname = flatname.replace(
                    "SPCALIB", os.environ["DESI_SPECTRO_CALIB"] + "/")
            if flatname.find("SPECPROD") >= 0:  # this one is harder :-(
                # go four directory levels up from the frame file
                dirname = os.path.dirname(
                    os.path.dirname(os.path.dirname(
                        os.path.dirname(filename))))
                flatname = flatname.replace("SPECPROD", dirname + "/")
        else:
            flatname = findcalibfile([
                frame.meta,
            ], "FIBERFLAT")

        if flatname is not None:
            if previous_fiberflat_filename is not None and previous_fiberflat_filename == flatname:
                # medflat from the previous frame is still valid
                log.info("Using same fiberflat")
            else:
                if not os.path.isfile(flatname):
                    log.error("Cannot open fiberflat file {}".format(flatname))
                    raise IOError(
                        "Cannot open fiberflat file {}".format(flatname))
                log.info("Using fiberflat {}".format(flatname))
                fiberflat = read_fiberflat(flatname)
                medflat = np.median(fiberflat.fiberflat, axis=1)
                previous_fiberflat_filename = flatname
        else:
            medflat = None
            log.warning("No fiberflat")

        skyfibers = np.where((frame.fibermap["OBJTYPE"] == "SKY")
                             & (frame.fibermap["FIBERSTATUS"] == 0))[0]
        log.info("{} sky fibers in {}".format(skyfibers.size, filename))

        # ignore BADFIBER which is a statement on the positioning
        frame.ivar *= ((frame.mask == 0) | (frame.mask == specmask.BADFIBER))

        # also open trace set to determine the shift
        # to apply to adjacent spectra
        psf_filename = findcalibfile([
            frame.meta,
        ], "PSF")

        # only reread if necessary
        if previous_psf_filename is None or previous_psf_filename != psf_filename:
            tset = read_xytraceset(psf_filename)
            previous_psf_filename = psf_filename

        # will use this y
        central_y = tset.npix_y // 2

        if A is None:
            # per-bundle accumulators, plus per-fiber scratch arrays (fA, fB)
            A = np.zeros((nbundles, npar, npar, out_wave.size))
            B = np.zeros((nbundles, npar, out_wave.size))
            fA = np.zeros((npar, npar, out_wave.size))
            fB = np.zeros((npar, out_wave.size))
            ninput = np.zeros((nbundles, dfiber.size))

        for skyfiber in skyfibers:
            cflux = np.zeros((dfiber.size, out_wave.size))
            skyfiberbundle = skyfiber // nfiber_per_bundle

            nbad = np.sum(frame.ivar[skyfiber] == 0)
            if nbad > 200:
                if nbad < 2000:
                    log.warning(
                        "ignore skyfiber {} from {} with {} masked pixel".
                        format(skyfiber, filename, nbad))
                continue

            skyfiber_central_wave = tset.wave_vs_y(skyfiber, central_y)

            should_consider = False
            must_exclude = False
            # reset the scratch arrays; fill() also clears NaN/inf values,
            # which a multiplication by zero would leave in place
            fA.fill(0.)
            fB.fill(0.)

            use_median_filter = False  # not needed
            median_filter_width = 30
            skyfiberflux, skyfiberivar = resample_flux(out_wave, frame.wave,
                                                       frame.flux[skyfiber],
                                                       frame.ivar[skyfiber])

            if medflat is not None:
                # apply relative transmission of fiber, i.e. undo the fiberflat correction
                skyfiberflux *= medflat[skyfiber]

            if use_median_filter:
                good = (skyfiberivar > 0)
                skyfiberflux = np.interp(out_wave, out_wave[good],
                                         skyfiberflux[good])
                skyfiberflux = scipy.ndimage.filters.median_filter(
                    skyfiberflux, median_filter_width, mode='constant')

            for i, df in enumerate(dfiber):
                otherfiber = df + skyfiber
                if otherfiber < 0:
                    continue
                if otherfiber >= frame.nspec:
                    continue
                if otherfiber // nfiber_per_bundle != skyfiberbundle:
                    continue  # not same bundle

                snr = np.sqrt(frame.ivar[otherfiber]) * frame.flux[otherfiber]
                medsnr = np.median(snr)
                if medsnr > 2:  # need good SNR to model cross talk
                    # in which case we need all of the contaminants to the sky fiber ...
                    should_consider = True

                nbad = np.sum(snr == 0)
                if nbad > 200:
                    if nbad < 2000:
                        log.warning(
                            "ignore fiber {} from {} with {} masked pixel".
                            format(otherfiber, filename, nbad))
                    must_exclude = True  # because 1 bad fiber
                    break

                if np.any(snr > 1000.):
                    log.error(
                        "signal to noise is suspiciously too high in fiber {} from {}"
                        .format(otherfiber, filename))
                    must_exclude = True  # because 1 bad fiber
                    break

                # interpolate over masked pixels or low snr pixels and shift
                medivar = np.median(frame.ivar[otherfiber])
                # interpolate over bright sky lines
                good = (frame.ivar[otherfiber] > 0.01 * medivar)
                # account for change of wavelength for same y coordinate
                otherfiber_central_wave = tset.wave_vs_y(otherfiber, central_y)
                flux = np.interp(
                    frame.wave +
                    (otherfiber_central_wave - skyfiber_central_wave),
                    frame.wave[good], frame.flux[otherfiber][good])
                if medflat is not None:
                    # apply relative transmission of fiber, i.e. undo the fiberflat correction
                    flux *= medflat[otherfiber]

                if use_median_filter:
                    flux = scipy.ndimage.filters.median_filter(
                        flux, median_filter_width, mode='constant')
                kern = kernels[np.abs(df)]
                tmp = fftconvolve(flux, kern, mode="same")
                cflux[i] = resample_flux(out_wave, frame.wave, tmp)
                # lower triangle of the normal equations for this sky fiber
                fB[i] = skyfiberivar * cflux[i] * skyfiberflux
                for j in range(i + 1):
                    fA[i, j] = skyfiberivar * cflux[i] * cflux[j]

            if (not should_consider) or must_exclude:
                continue

            # require a minimum weighted-mean contaminating flux
            scflux = np.sum(cflux, axis=0)
            mscflux = np.sum(skyfiberivar * scflux) / np.sum(skyfiberivar)
            if mscflux < 100:
                continue

            if with_cst:
                icst = dfiber.size
                fA[icst, icst] = skyfiberivar  # constant term
                fB[icst] = skyfiberivar * skyfiberflux
                for j in range(icst):
                    fA[icst, j] = skyfiberivar * cflux[j]

            # just stack all wavelength to get 1 number for this fiber
            scflux = np.sum(cflux[np.abs(dfiber) == 1], axis=0)
            a = np.sum(skyfiberivar * scflux**2)
            b = np.sum(skyfiberivar * scflux * skyfiberflux)
            xtalk = b / a
            err = 1. / np.sqrt(a)
            msky = np.sum(skyfiberivar * skyfiberflux) / np.sum(skyfiberivar)
            ra = frame.fibermap["TARGET_RA"][skyfiber]
            dec = frame.fibermap["TARGET_DEC"][skyfiber]

            # a large and significant single-fiber value is treated as an outlier
            if np.abs(xtalk) > 0.02 and np.abs(xtalk) / err > 5:
                log.warning(
                    "discard skyfiber = {}, xtalk = {:4.3f} +- {:4.3f}, ra = {:5.4f} , dec = {:5.4f}, sky fiber flux= {:4.3f}, cont= {:4.3f}"
                    .format(skyfiber, xtalk, err, ra, dec, msky, mscflux))
                continue

            for i in range(dfiber.size):
                ninput[skyfiberbundle, i] += int(np.sum(fB[i]) != 0)  # to monitor
            B[skyfiberbundle] += fB
            A[skyfiberbundle] += fA

    if A is None:
        # no frame was read: there is no wavelength grid nor anything to solve
        message = "no input frame, cannot compute crosstalk"
        log.error(message)
        raise ValueError(message)

    # only the lower triangle was filled above, copy it to the upper one
    for bundle in range(nbundles):
        for i in range(npar):
            for j in range(i):
                A[bundle, j, i] = A[bundle, i, j]

    # now solve
    crosstalk = np.zeros((nbundles, dfiber.size, out_wave.size))
    crosstalk_ivar = np.zeros((nbundles, dfiber.size, out_wave.size))
    nsingular = 0
    for bundle in range(nbundles):
        for j in range(out_wave.size):
            try:
                Ai = np.linalg.inv(A[bundle, :, :, j])
            except np.linalg.LinAlgError:
                # unconstrained system: keep zero coefficient and zero ivar
                nsingular += 1
                continue
            coef = Ai.dot(B[bundle, :, j])
            var = np.diag(Ai)
            if with_cst:
                # last coefficient is constant
                coef = coef[:-1]
                var = var[:-1]
            crosstalk[bundle, :, j] = coef
            crosstalk_ivar[bundle, :, j] = 1. / var

    if nsingular > 0:
        log.warning(
            "{} (bundle, wavelength) systems out of {} are singular, crosstalk set to 0 with ivar=0"
            .format(nsingular, nbundles * out_wave.size))

    table = Table()
    table["WAVELENGTH"] = out_wave
    for bundle in range(nbundles):
        for i, df in enumerate(dfiber):
            key = "CROSSTALK-B{:02d}-F{:+d}".format(bundle, df)
            table[key] = crosstalk[bundle, i]
            key = "CROSSTALKIVAR-B{:02d}-F{:+d}".format(bundle, df)
            table[key] = crosstalk_ivar[bundle, i]
            key = "NINPUT-B{:02d}-F{:+d}".format(bundle, df)
            table[key] = np.repeat(ninput[bundle, i], out_wave.size)

    table.write(args.outfile, overwrite=True)
    log.info("wrote {}".format(args.outfile))

    log.info("number of sky fibers used per bundle:")
    for bundle in range(nbundles):
        log.info("bundle {}: {}".format(bundle, ninput[bundle]))

    if args.plot:
        for bundle in range(nbundles):
            for i, df in enumerate(dfiber):
                # adding (ivar == 0) avoids a division by zero for unsolved bins
                err = 1. / np.sqrt(crosstalk_ivar[bundle, i] +
                                   (crosstalk_ivar[bundle, i] == 0))
                plt.errorbar(out_wave,
                             crosstalk[bundle, i],
                             err,
                             fmt="o-",
                             label="bundle = {:02d} dfiber = {:+d}".format(
                                 bundle, df))
        plt.grid()
        plt.legend()
        plt.show()