Ejemplo n.º 1
0
def spline_fit(output_wave,
               input_wave,
               input_flux,
               required_resolution,
               input_ivar=None,
               order=3):
    """Fit a spline to input_flux vs. input_wave and evaluate it at output_wave.

    Args:
        output_wave : wavelengths at which the fitted spline is evaluated
        input_wave : input wavelengths
        input_flux : input flux values, one per input wavelength
        required_resolution : requested spacing between spline knots

    Options:
        input_ivar : weights handed to scipy.interpolate.splrep; when given,
            the knots only span the range between the first and the last
            sample with input_ivar > 0
        order (int) : spline degree handed to splrep

    Returns:
        output_flux : spline values at output_wave

    Raises:
        ValueError: if input_ivar is given and fewer than two of its
            values are > 0
    """
    if input_ivar is not None:
        selection = np.where(input_ivar > 0)[0]
        if selection.size < 2:
            log = get_logger()
            message = (
                "cannot do spline fit because only {0:d} values with ivar>0".
                format(selection.size))
            log.error(message)
            # The original `raise Error` referenced an undefined name and
            # therefore surfaced as a NameError.
            raise ValueError(message)
        w1 = input_wave[selection[0]]
        w2 = input_wave[selection[-1]]
    else:
        w1 = input_wave[0]
        w2 = input_wave[-1]

    # Number of knots from the requested spacing, then the actual spacing
    # that spreads n knots evenly inside [w1, w2].
    res = required_resolution
    n = int((w2 - w1) / res)
    res = (w2 - w1) / (n + 1)
    # Knots start half a spacing after w1, so all of them are interior.
    knots = w1 + res * (0.5 + np.arange(n))
    # task=-1: least-squares spline on the supplied knots.
    toto = scipy.interpolate.splrep(input_wave,
                                    input_flux,
                                    w=input_ivar,
                                    k=order,
                                    task=-1,
                                    t=knots)
    output_flux = scipy.interpolate.splev(output_wave, toto)
    return output_flux
Ejemplo n.º 2
0
def main():
    """Command-line entry point: compute a fiberflat from a frame file.

    Parses --infile and --outfile (both required), reads the frame,
    computes the fiberflat and writes it to the output file together with
    the frame header.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Both options are mandatory string paths.
    required_path = dict(type=str, default=None, required=True)
    cli.add_argument(
        '--infile',
        help=
        'path of DESI frame fits file corresponding to a continuum lamp exposure',
        **required_path)
    cli.add_argument('--outfile',
                     help='path of DESI fiberflat fits file',
                     **required_path)

    opts = cli.parse_args()
    logger = get_logger()

    logger.info("starting")

    lamp_frame = read_frame(opts.infile)
    flat = compute_fiberflat(lamp_frame)
    write_fiberflat(opts.outfile, flat, lamp_frame.header)

    logger.info("successfully wrote %s" % opts.outfile)
Ejemplo n.º 3
0
def compatible(head1, head2) :
    """Tell whether two headers agree on a fixed list of keywords.

    Args:
        head1, head2 : header-like objects indexable by keyword

    Returns:
        True if every checked keyword has the same value in both headers,
        False at the first keyword that differs (which is logged as a
        warning).
    """
    log = get_logger()
    #- "FIBERMAX" was listed twice; the repeated comparison was redundant
    keys = ("PSFTYPE","NPIX_X","NPIX_Y","HSIZEX","HSIZEY","BUNDLMIN","BUNDLMAX",
            "FIBERMIN","FIBERMAX","NPARAMS","LEGDEG","GHDEGX","GHDEGY")
    for k in keys :
        if head1[k] != head2[k] :
            log.warning("different {} : {}, {}".format(k, head1[k], head2[k]))
            return False
    return True
Ejemplo n.º 4
0
def apply_fiberflat(frame, fiberflat):
    """Apply fiberflat to frame.  Modifies frame.flux and frame.ivar

    Args:
        frame : object providing wave, flux and ivar attributes
        fiberflat : object providing wave, fiberflat and ivar attributes

    The flux is divided by the fiberflat where the fiberflat is > 0 and is
    set to 0 elsewhere.  The ivar is propagated through the division and
    is 0 wherever the frame ivar, the fiberflat ivar or the fiberflat is
    not > 0.

    Raises:
        ValueError: if frame.wave and fiberflat.wave differ (np.allclose)
    """
    log = get_logger()
    log.info("starting")

    # check same wavelength, die if not the case
    if not np.allclose(frame.wave, fiberflat.wave):
        message = "frame and fiberflat do not have the same wavelength arrays"
        log.critical(message)
        raise ValueError(message)

    # Error propagation for F' = F/C  (F = flux, C = fiberflat):
    #  Var(F') = Var(F)/C**2 + F**2*(  d(1/C)/dC )**2*Var(C)
    #          = 1/(ivar(F)*C**2) + F**2*(1/C**2)**2*Var(C)
    #          = 1/(ivar(F)*C**2) + F**2*Var(C)/C**4
    #          = 1/(ivar(F)*C**2) + F**2/(ivar(C)*C**4)

    #- shorthand
    ff = fiberflat
    sp = frame  #- sp=spectra for this frame

    #- Update sp.ivar first: the formula above needs the ORIGINAL flux F.
    #- Previously the flux was divided by the flat before this line, so the
    #- second variance term was computed from F/C instead of F.
    #- The "+ (x == 0)" terms only keep the denominators non-zero; those
    #- pixels are zeroed by the boolean factors in front.
    sp.ivar = (sp.ivar > 0) * (ff.ivar > 0) * (ff.fiberflat > 0) / (1. / (
        (sp.ivar + (sp.ivar == 0)) *
        (ff.fiberflat**2 +
         (ff.fiberflat == 0))) + sp.flux**2 / (ff.ivar * ff.fiberflat**4 +
                                               (ff.ivar * ff.fiberflat == 0)))

    #- Then flat-field the flux, avoiding a division by 0
    sp.flux = sp.flux * (ff.fiberflat > 0) / (ff.fiberflat +
                                              (ff.fiberflat == 0))

    log.info("done")
Ejemplo n.º 5
0
def main(args) :
    """Compute a fiberflat from a frame, optionally run QA, then write it.

    Args:
        args : parsed options providing infile, outfile, qafile and qafig
    """
    logger=get_logger()
    logger.info("starting")

    # Fiberflat computation
    frame = read_frame(args.infile)
    fiberflat = compute_fiberflat(frame)

    # QA is only performed when a QA file was requested
    qafile = args.qafile
    if qafile is not None:
        logger.info("performing fiberflat QA")
        qaframe = load_qa_frame(qafile, frame, flavor=frame.meta['FLAVOR'])
        qaframe.run_qa('FIBERFLAT', (frame, fiberflat))
        # qafile is known to be set here, so the QA frame is always written
        write_qa_frame(qafile, qaframe)
        logger.info("successfully wrote {:s}".format(qafile))
        # Optional figure
        if args.qafig is not None:
            qa_plots.frame_fiberflat(args.qafig, qaframe, frame, fiberflat)

    # Output
    write_fiberflat(args.outfile, fiberflat, frame.meta)
    logger.info("successfully wrote %s"%args.outfile)
Ejemplo n.º 6
0
def spline_fit(output_wave,input_wave,input_flux,required_resolution,input_ivar=None,order=3):
    """Performs spline fit of input_flux vs. input_wave and resamples at output_wave

    Args:
        output_wave : 1D array of output wavelength samples
        input_wave : 1D array of input wavelengths
        input_flux : 1D array of input flux density
        required_resolution (float) : resolution for spline knot placement

    Options:
        input_ivar : 1D array of weights for input_flux; when given, knots
            only span the range between the first and last sample with
            input_ivar > 0
        order (int) : spline order

    Returns:
        output_flux : 1D array of flux sampled at output_wave

    Raises:
        ValueError : if input_ivar is given and fewer than two of its
            values are > 0
    """
    if input_ivar is not None :
        selection=np.where(input_ivar>0)[0]
        if selection.size < 2 :
            log=get_logger()
            message="cannot do spline fit because only {0:d} values with ivar>0".format(selection.size)
            log.error(message)
            #- carry the logged message in the exception as well
            raise ValueError(message)
        w1=input_wave[selection[0]]
        w2=input_wave[selection[-1]]
    else :
        w1=input_wave[0]
        w2=input_wave[-1]

    #- number of knots from the requested spacing, then the actual spacing
    #- that spreads n knots evenly inside [w1,w2]
    res=required_resolution
    n=int((w2-w1)/res)
    res=(w2-w1)/(n+1)
    #- knots start half a spacing after w1, so all of them are interior
    knots=w1+res*(0.5+np.arange(n))
    #- task=-1 : least-squares spline on the supplied knots
    toto=scipy.interpolate.splrep(input_wave,input_flux,w=input_ivar,k=order,task=-1,t=knots)
    output_flux = scipy.interpolate.splev(output_wave,toto)
    return output_flux
Ejemplo n.º 7
0
def main(args):
    """Run the fiberflat step: compute, optionally QA, and write the result.

    Args:
        args : parsed options providing infile, outfile, qafile and qafig
    """
    log = get_logger()
    log.info("starting")

    # Read the input frame and derive the fiberflat from it
    frame = read_frame(args.infile)
    fiberflat = compute_fiberflat(frame)

    qa_requested = args.qafile is not None
    if qa_requested:
        log.info("performing fiberflat QA")
        # Load the QA frame, run the FIBERFLAT QA and save it
        flavor = frame.meta['FLAVOR']
        qaframe = load_qa_frame(args.qafile, frame, flavor=flavor)
        qaframe.run_qa('FIBERFLAT', (frame, fiberflat))
        write_qa_frame(args.qafile, qaframe)
        log.info("successfully wrote {:s}".format(args.qafile))
        # The figure is produced only when a figure path was given
        if args.qafig is not None:
            qa_plots.frame_fiberflat(args.qafig, qaframe, frame, fiberflat)

    # Write the fiberflat itself
    write_fiberflat(args.outfile, fiberflat, frame.meta)
    log.info("successfully wrote %s" % args.outfile)
Ejemplo n.º 8
0
def apply_fiberflat(frame, fiberflat):
    """Apply fiberflat to frame.  Modifies frame.flux and frame.ivar

    Args:
        frame : object providing wave, flux and ivar attributes
        fiberflat : object providing wave, fiberflat and ivar attributes

    The flux is divided by the fiberflat where the fiberflat is > 0 and is
    set to 0 elsewhere.  The ivar is propagated through the division and
    is 0 wherever the frame ivar, the fiberflat ivar or the fiberflat is
    not > 0.

    Raises:
        ValueError: if frame.wave and fiberflat.wave differ (np.allclose)
    """
    log=get_logger()
    log.info("starting")

    # check same wavelength, die if not the case
    if not np.allclose(frame.wave, fiberflat.wave):
        message = "frame and fiberflat do not have the same wavelength arrays"
        log.critical(message)
        raise ValueError(message)

    # Error propagation for F' = F/C  (F = flux, C = fiberflat):
    #  Var(F') = Var(F)/C**2 + F**2*(  d(1/C)/dC )**2*Var(C)
    #          = 1/(ivar(F)*C**2) + F**2*(1/C**2)**2*Var(C)
    #          = 1/(ivar(F)*C**2) + F**2*Var(C)/C**4
    #          = 1/(ivar(F)*C**2) + F**2/(ivar(C)*C**4)

    #- shorthand
    ff = fiberflat
    sp = frame  #- sp=spectra for this frame

    #- Update sp.ivar first: the formula above needs the ORIGINAL flux F.
    #- Previously the flux was divided by the flat before this line, so the
    #- second variance term was computed from F/C instead of F.
    #- The "+(x==0)" terms only keep the denominators non-zero; those
    #- pixels are zeroed by the boolean factors in front.
    sp.ivar=(sp.ivar>0)*(ff.ivar>0)*(ff.fiberflat>0)/( 1./((sp.ivar+(sp.ivar==0))*(ff.fiberflat**2+(ff.fiberflat==0))) + sp.flux**2/(ff.ivar*ff.fiberflat**4+(ff.ivar*ff.fiberflat==0)) )

    #- Then flat-field the flux, avoiding a division by 0
    sp.flux = sp.flux*(ff.fiberflat>0)/(ff.fiberflat+(ff.fiberflat==0))

    log.info("done")
Ejemplo n.º 9
0
    def __init__(self, wave, flux=None, ivar=None, mask=None, resolution=None):
        """Store the input arrays and set up the co-addition accumulators.

        wave must be 1D; flux, ivar and mask, when given, must have the
        same shape as wave; resolution, when given, must be a
        desispec.resolution.Resolution whose first dimension matches
        len(wave).  All of these are checked with assert statements.

        Without ivar the accumulators Cinv and Cinv_f start as dense zero
        arrays.  With ivar (which then also requires flux and resolution)
        they are R^T.diag(ivar).R and R^T.(ivar*flux), R being resolution.
        """
        assert wave.ndim == 1, "Input wavelength should be 1D"
        for label, values in (("flux", flux), ("ivar", ivar), ("mask", mask)):
            assert (values is None) or (values.shape == wave.shape), \
                "wave and {} should have same shape".format(label)
        if resolution is not None:
            assert isinstance(resolution, desispec.resolution.Resolution)
            assert resolution.shape[0] == len(wave), \
                "resolution size mismatch to wave"

        self.wave, self.flux, self.ivar, self.mask = wave, flux, ivar, mask
        #- R is a shorthand alias for resolution
        self.resolution = self.R = resolution
        self.log = get_logger()

        # Quantities accumulated during co-addition.
        if ivar is not None:
            assert flux is not None and resolution is not None, 'Missing flux and/or resolution.'
            ivar_on_diagonal = scipy.sparse.dia_matrix(
                (ivar[np.newaxis, :], [0]), resolution.shape)
            weighted_resolution = ivar_on_diagonal.dot(self.resolution)
            self.Cinv = self.resolution.T.dot(weighted_resolution)
            self.Cinv_f = self.resolution.T.dot(self.ivar * self.flux)
        else:
            # Nothing to accumulate yet: start from dense zeros.
            npix = len(wave)
            self.Cinv = np.zeros((npix, npix))
            self.Cinv_f = np.zeros((npix, ))
Ejemplo n.º 10
0
 def __init__(self,wave,flux=None,ivar=None,mask=None,resolution=None):
     """Keep the spectrum arrays and initialise Cinv and Cinv_f.

     The inputs are validated with assert statements: wave is 1D; flux,
     ivar and mask, when given, share the shape of wave; resolution, when
     given, is a desispec.resolution.Resolution with shape[0] == len(wave).

     When ivar is None, Cinv and Cinv_f are dense zero arrays; otherwise
     flux and resolution are required and Cinv = R^T.diag(ivar).R and
     Cinv_f = R^T.(ivar*flux), where R is the resolution.
     """
     assert wave.ndim == 1, "Input wavelength should be 1D"
     expected_shape = wave.shape
     assert (flux is None) or (flux.shape == expected_shape), "wave and flux should have same shape"
     assert (ivar is None) or (ivar.shape == expected_shape), "wave and ivar should have same shape"
     assert (mask is None) or (mask.shape == expected_shape), "wave and mask should have same shape"
     no_resolution = resolution is None
     assert no_resolution or isinstance(resolution, desispec.resolution.Resolution)
     assert no_resolution or (resolution.shape[0] == len(wave)), "resolution size mismatch to wave"

     self.wave = wave
     self.flux = flux
     self.ivar = ivar
     #- the mask is stored exactly as given (no conversion is applied)
     self.mask = mask
     self.resolution = resolution
     self.R = resolution #- shorthand
     self.log = get_logger()

     # Accumulators used during co-addition.
     if ivar is None:
         size = len(wave)
         self.Cinv = np.zeros((size,size))
         self.Cinv_f = np.zeros((size,))
         return

     assert flux is not None and resolution is not None,'Missing flux and/or resolution.'
     ivar_matrix = scipy.sparse.dia_matrix((ivar[np.newaxis,:],[0]),resolution.shape)
     self.Cinv = self.resolution.T.dot(ivar_matrix.dot(self.resolution))
     self.Cinv_f = self.resolution.T.dot(self.ivar*self.flux)
Ejemplo n.º 11
0
def check_env():
    """
    Check required environment variables; exit the process if any is missing.

    Every missing variable is logged as a warning, followed by a short
    explanation.  Only after everything has been reported is a critical
    message logged and sys.exit(1) called.
    """
    log = get_logger()

    #- template locations
    template_vars = ['DESI_'+objtype+'_TEMPLATES'
                     for objtype in ('ELG', 'LRG', 'STD', 'QSO')]
    missing_templates = [name for name in template_vars
                         if name not in os.environ]
    for name in missing_templates:
        log.warning('missing ${0} needed for simulating spectra'.format(name))
    if missing_templates:
        log.warning('    e.g. see NERSC:/project/projectdirs/desi/datachallenge/dc2/templates/')

    #- data and production locations
    pipeline_vars = ('DESI_SPECTRO_SIM', 'DESI_SPECTRO_DATA',
                     'DESI_SPECTRO_REDUX', 'PIXPROD', 'PRODNAME')
    missing_pipeline = [name for name in pipeline_vars
                        if name not in os.environ]
    for name in missing_pipeline:
        log.warning("missing ${0}".format(name))
    if missing_pipeline:
        log.warning("Why are these needed?")
        log.warning("    Simulations written to $DESI_SPECTRO_SIM/$PIXPROD/")
        log.warning("    Raw data read from $DESI_SPECTRO_DATA/")
        log.warning("    Spectro pipeline output written to $DESI_SPECTRO_REDUX/$PRODNAME/")

    #- Fail only now, so that everything missing has been reported first
    if missing_pipeline or missing_templates:
        log.critical("missing env vars; exiting without running pipeline")
        sys.exit(1)
Ejemplo n.º 12
0
def qa_zbest(param, zf, brick):
    """Collect redshift-fitting QA numbers into a dict.

    Args:
        param : dict of QA parameters; MAX_NFAIL, ELG_TYPES, QSO_TYPES and
            STAR_TYPES are read
        zf : ZfindBase object; zwarn, z, targetid and spectype are read
        brick : brick dict; only its first entry is used
    Returns:
        qa_zbest: dict with keys NFAIL, MEAN_Z, MEDIAN_Z, RMS_Z and NTYPE.
            NTYPE counts the fitted types (ELG, QSO, STAR, UNKWN) and how
            often the brick OBJTYPE falls in the same class (MATCH); its
            LRG entry is initialised but never incremented here.
    """
    log = get_logger()

    # Table of the first brick entry
    first_key = list(brick.keys())[0]
    btbl = brick[first_key].hdu_list[4].data

    qadict = {}

    # Number of failed fits (zwarn > 0), stored as a plain int for yaml
    nfail = np.sum(zf.zwarn > 0)
    qadict['NFAIL'] = int(nfail)
    if nfail > param['MAX_NFAIL']:
        log.warning("High number of failed redshifts {:d}".format(nfail))

    # Statistics over the redshifts without warning
    good_z = zf.z[zf.zwarn == 0]
    qadict['MEAN_Z'] = float(np.mean(good_z))
    qadict['MEDIAN_Z'] = float(np.median(good_z))
    qadict['RMS_Z'] = float(np.std(good_z))

    # Locate each zf target in the brick table through a sorted lookup
    srt = np.argsort(btbl['TARGETID'])
    left = np.searchsorted(btbl['TARGETID'], zf.targetid, side='left',
                           sorter=srt)

    # Count fitted types; the first class containing the type wins
    counts = dict(ELG=0, QSO=0, LRG=0, STAR=0, UNKWN=0, MATCH=0)
    qadict['NTYPE'] = counts
    classes = (('ELG', 'ELG_TYPES'),
               ('QSO', 'QSO_TYPES'),
               ('STAR', 'STAR_TYPES'))
    for kk, ztype in enumerate(zf.spectype):
        # Row of this target in the brick table
        idx = srt[left[kk]]
        for label, types_key in classes:
            if ztype in param[types_key]:
                counts[label] += 1
                if btbl[idx]['OBJTYPE'] in param[types_key]:
                    counts['MATCH'] += 1
                break
        else:
            counts['UNKWN'] += 1

    return qadict
Ejemplo n.º 13
0
def qa_zbest(param, zf, brick):
    """Collect redshift-fitting QA numbers into a dict.

    Args:
        param : dict of QA parameters; MAX_NFAIL, ELG_TYPES, QSO_TYPES and
            STAR_TYPES are read
        zf : ZfindBase object; zwarn, z, targetid and spectype are read
        brick : brick dict; only its first entry is used
    Returns:
        qa_zbest: dict with keys NFAIL, MEAN_Z, MEDIAN_Z, RMS_Z and NTYPE.
            NTYPE counts the fitted types (ELG, QSO, STAR, UNKWN) and how
            often the brick OBJTYPE falls in the same class (MATCH); its
            LRG entry is initialised but never incremented here.
    """
    log = get_logger()

    # Parse brick table
    # dict views are not subscriptable on Python 3, so take the first key
    # through an iterator instead of brick.keys()[0]
    key = next(iter(brick.keys()))
    btbl = brick[key].hdu_list[4].data

    # Output dict
    qadict = {}

    # Failures
    nfail = np.sum(zf.zwarn > 0)
    qadict['NFAIL'] = int(nfail)  # For yaml
    if nfail > param['MAX_NFAIL']:
        # Logger.warn is a deprecated alias; use warning()
        log.warning("High number of failed redshifts {:d}".format(nfail))

    # Simple redshift stats over the fits without warning
    gdz = zf.zwarn == 0
    qadict['MEAN_Z'] = float(np.mean(zf.z[gdz]))
    qadict['MEDIAN_Z'] = float(np.median(zf.z[gdz]))
    qadict['RMS_Z'] = float(np.std(zf.z[gdz]))

    # Match zf ID to brick ID
    srt = np.argsort(btbl['TARGETID'])
    left = np.searchsorted(btbl['TARGETID'], zf.targetid,
                           side='left', sorter=srt)

    # Types (ELG, QSO, LRG, STAR, ??) -- Need to allow for multiple of target options
    qadict['NTYPE'] = dict(ELG=0, QSO=0, LRG=0, STAR=0, UNKWN=0, MATCH=0)
    for kk, ztype in enumerate(zf.spectype):
        # Brick index
        idx = srt[left[kk]]
        # The first class containing the fitted type wins; MATCH is
        # incremented when the brick OBJTYPE is in the same class
        if ztype in param['ELG_TYPES']:
            qadict['NTYPE']['ELG'] += 1
            if btbl[idx]['OBJTYPE'] in param['ELG_TYPES']:
                qadict['NTYPE']['MATCH'] += 1
        elif ztype in param['QSO_TYPES']:
            qadict['NTYPE']['QSO'] += 1
            if btbl[idx]['OBJTYPE'] in param['QSO_TYPES']:
                qadict['NTYPE']['MATCH'] += 1
        elif ztype in param['STAR_TYPES']:
            qadict['NTYPE']['STAR'] += 1
            if btbl[idx]['OBJTYPE'] in param['STAR_TYPES']:
                qadict['NTYPE']['MATCH'] += 1
        else:
            qadict['NTYPE']['UNKWN'] += 1

    # Return
    return qadict
Ejemplo n.º 14
0
def main():
    """Command-line entry point: compute a sky model for one exposure frame.

    Reads the frame and the fibermap, exits with status 12 if the fibermap
    holds no SKY fiber within the frame's SPECMIN..SPECMAX range, applies
    the fiberflat to the frame, computes the sky model and writes it to
    --outfile together with the frame header.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--infile',
                        type=str,
                        default=None,
                        required=True,
                        help='path of DESI exposure frame fits file')
    # The help text used to be a copy of the --infile one.
    parser.add_argument('--fibermap',
                        type=str,
                        default=None,
                        required=True,
                        help='path of DESI fibermap fits file')
    parser.add_argument('--fiberflat',
                        type=str,
                        default=None,
                        required=True,
                        help='path of DESI fiberflat fits file')
    parser.add_argument('--outfile',
                        type=str,
                        default=None,
                        required=True,
                        help='path of DESI sky fits file')

    args = parser.parse_args()
    log = get_logger()

    log.info("starting")

    # read exposure to load data and get range of spectra
    frame = read_frame(args.infile)
    specmin = frame.header["SPECMIN"]
    specmax = frame.header["SPECMAX"]

    # read fibermap to locate sky fibers
    fibermap = read_fibermap(args.fibermap)
    selection = np.where((fibermap["OBJTYPE"] == "SKY")
                         & (fibermap["FIBER"] >= specmin)
                         & (fibermap["FIBER"] <= specmax))[0]
    # the selection is only used to check that sky fibers exist at all
    if selection.size == 0:
        log.error("no sky fiber in fibermap %s" % args.fibermap)
        sys.exit(12)

    # read fiberflat
    fiberflat = read_fiberflat(args.fiberflat)

    # apply fiberflat to the whole frame (modified in place)
    apply_fiberflat(frame, fiberflat)

    # compute sky model
    skymodel = compute_sky(frame, fibermap)

    # write result
    write_sky(args.outfile, skymodel, frame.header)

    log.info("successfully wrote %s" % args.outfile)
Ejemplo n.º 15
0
def main(args) :
    """Compute the flux calibration of a frame from standard star models.

    The frame is read, flat-fielded and sky-subtracted, the standard star
    models are read and checked against the frame fibermap, the flux
    calibration is computed, optional QA is run, and the result is written
    to args.outfile.

    Args:
        args : parsed options providing infile, fiberflat, sky, models,
            outfile, qafile and qafig

    Exits with status 12 (sys.exit) if any model fiber is not flagged
    'STD' in the frame fibermap.
    """
    log=get_logger()

    log.info("read frame")
    # read frame
    frame = read_frame(args.infile)

    log.info("apply fiberflat")
    # read fiberflat
    fiberflat = read_fiberflat(args.fiberflat)

    # apply fiberflat
    apply_fiberflat(frame, fiberflat)

    log.info("subtract sky")
    # read sky
    skymodel=read_sky(args.sky)

    # subtract sky
    subtract_sky(frame, skymodel)

    log.info("compute flux calibration")

    # read models
    model_flux,model_wave,model_fibers=read_stdstar_models(args.models)

    # check that the model_fibers are actually standard stars
    fibermap = frame.fibermap
    # NOTE(review): presumably maps fiber numbers to indices within one
    # 500-fiber frame -- confirm
    model_fibers = model_fibers%500
    if np.any(fibermap['OBJTYPE'][model_fibers] != 'STD'):
        # report only the offending fibers; every model fiber used to be
        # logged as inconsistent, including the valid ones
        for i in model_fibers:
            if fibermap["OBJTYPE"][i] != 'STD':
                log.error("inconsistency with spectrum %d, OBJTYPE='%s' in fibermap"%(i,fibermap["OBJTYPE"][i]))
        sys.exit(12)

    fluxcalib = compute_flux_calibration(frame, model_wave, model_flux)

    # QA
    if (args.qafile is not None):
        log.info("performing fluxcalib QA")
        # Load
        qaframe = load_qa_frame(args.qafile, frame, flavor=frame.meta['FLAVOR'])
        # Run
        qaframe.run_qa('FLUXCALIB', (frame, fluxcalib))
        # Write (args.qafile is known to be set in this branch)
        write_qa_frame(args.qafile, qaframe)
        log.info("successfully wrote {:s}".format(args.qafile))
        # Figure(s)
        if args.qafig is not None:
            qa_plots.frame_fluxcalib(args.qafig, qaframe, frame, fluxcalib)

    # write result
    write_flux_calibration(args.outfile, fluxcalib, header=frame.meta)

    log.info("successfully wrote %s"%args.outfile)
Ejemplo n.º 16
0
def bin_bounds(x):
    """Calculates the bin boundaries of an array `x`.

    Each boundary is the midpoint between a value and its neighbour in
    sorted order; the outermost boundaries are extrapolated with the
    spacing of the nearest pair of values.

    Args:
        x : 1D numpy array with at least 2 elements, in any order

    Returns tuple of lower and upper bounds, each with same length as `x`
    and in the same order as `x`.

    Raises:
        SystemExit: with status 12 if `x` has fewer than 2 elements.
    """
    if x.size < 2:
        get_logger().error("bin_bounds, x.size={0:d}".format(x.size))
        # Same effect as exit(12), without relying on the `exit` helper
        # that is only injected by the site module.
        raise SystemExit(12)

    # Work on the sorted values and remember how to undo the sort.
    order = np.argsort(x)
    tx = x[order]

    # Neighbours below, with the first one extrapolated.
    below = np.roll(tx, 1)
    below[0] = below[1] + tx[0] - tx[1]
    # Neighbours above, with the last one extrapolated.
    above = np.roll(tx, -1)
    above[-1] = above[-2] + tx[-1] - tx[-2]

    lower_sorted = 0.5 * (tx + below)
    upper_sorted = 0.5 * (tx + above)

    # Put the bounds back in the order of x.  The original averaged the
    # sorted neighbours with the unsorted x, which is only correct when x
    # is already sorted.
    x_minus = np.empty_like(lower_sorted)
    x_plus = np.empty_like(upper_sorted)
    x_minus[order] = lower_sorted
    x_plus[order] = upper_sorted
    return x_minus, x_plus
Ejemplo n.º 17
0
def bin_bounds(x):
    """Calculates the bin boundaries of an array `x`.

    Each boundary is the midpoint between a value and its neighbour in
    sorted order; the outermost boundaries are extrapolated with the
    spacing of the nearest pair of values.

    Args:
        x : 1D numpy array with at least 2 elements, in any order

    Returns tuple of lower and upper bounds, each with same length as `x`
    and in the same order as `x`.

    Raises:
        SystemExit: with status 12 if `x` has fewer than 2 elements.
    """
    if x.size<2 :
        get_logger().error("bin_bounds, x.size={0:d}".format(x.size))
        #- same effect as exit(12), without relying on the `exit` helper
        #- that is only injected by the site module
        raise SystemExit(12)

    #- work on the sorted values and remember how to undo the sort
    order=np.argsort(x)
    tx=x[order]

    #- neighbours below (first one extrapolated) and above (last one extrapolated)
    below=np.roll(tx,1)
    below[0]=below[1]+tx[0]-tx[1]
    above=np.roll(tx,-1)
    above[-1]=above[-2]+tx[-1]-tx[-2]

    lower_sorted=0.5*(tx+below)
    upper_sorted=0.5*(tx+above)

    #- put the bounds back in the order of x ; the original averaged the
    #- sorted neighbours with the unsorted x, which is only correct when x
    #- is already sorted
    x_minus=np.empty_like(lower_sorted)
    x_plus=np.empty_like(upper_sorted)
    x_minus[order]=lower_sorted
    x_plus[order]=upper_sorted
    return x_minus,x_plus
Ejemplo n.º 18
0
def main(args):
    """Load the basis templates for the requested object type.

    Args:
        args : already-parsed options; `verbose` and `objtype` are read

    The templates are read but not used further by this function yet.
    """
    # NOTE: an ArgumentParser used to be built here but was never used to
    # parse anything (args is supplied by the caller), so it was removed
    # together with a leftover pdb.set_trace() breakpoint.

    # Set up the logger.
    if args.verbose:
        log = get_logger(DEBUG)
    else:
        log = get_logger()

    objtype = args.objtype.upper()
    log.debug('Using OBJTYPE {}'.format(objtype))

    baseflux, basewave, basemeta = read_basis_templates(objtype=objtype)
Ejemplo n.º 19
0
def compatible(head1, head2):
    """Tell whether two headers agree on a fixed list of keywords.

    Args:
        head1, head2 : header-like objects indexable by keyword

    Returns:
        True if every checked keyword has the same value in both headers,
        False at the first keyword that differs (which is logged as a
        warning).
    """
    log = get_logger()
    # "FIBERMAX" was listed twice; the repeated comparison was redundant.
    keys = (
        "PSFTYPE", "NPIX_X", "NPIX_Y", "HSIZEX", "HSIZEY", "BUNDLMIN",
        "BUNDLMAX", "FIBERMIN", "FIBERMAX", "NPARAMS", "LEGDEG", "GHDEGX",
        "GHDEGY",
    )
    for k in keys:
        if head1[k] != head2[k]:
            log.warning("different {} : {}, {}".format(k, head1[k], head2[k]))
            return False
    return True
Ejemplo n.º 20
0
def main() :
    """Command-line entry point: process one exposure frame.

    Depending on which of --fiberflat, --sky and --calib are given, the
    frame read from --infile is flat-fielded, sky-subtracted and flux
    calibrated (in that order) and then written to --outfile.  Exits with
    status 12 if none of the three options is given.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--infile', type = str, default = None, required=True,
                        help = 'path of DESI exposure frame fits file')
    parser.add_argument('--fiberflat', type = str, default = None,
                        help = 'path of DESI fiberflat fits file')
    parser.add_argument('--sky', type = str, default = None,
                        help = 'path of DESI sky fits file')
    parser.add_argument('--calib', type = str, default = None,
                        help = 'path of DESI calibration fits file')
    # the output is a frame (see write_frame below); the help text used to
    # be a copy of the --sky one
    parser.add_argument('--outfile', type = str, default = None, required=True,
                        help = 'path of output DESI frame fits file')

    args = parser.parse_args()
    log = get_logger()

    if (args.fiberflat is None) and (args.sky is None) and (args.calib is None):
        log.critical('no --fiberflat, --sky, or --calib; nothing to do ?!?')
        sys.exit(12)

    frame = read_frame(args.infile)

    if args.fiberflat is not None :
        log.info("apply fiberflat")
        # read fiberflat
        fiberflat = read_fiberflat(args.fiberflat)

        # apply fiberflat to the frame
        apply_fiberflat(frame, fiberflat)

    if args.sky is not None :
        log.info("subtract sky")
        # read sky
        skymodel=read_sky(args.sky)
        # subtract sky
        subtract_sky(frame, skymodel)

    if args.calib is not None :
        log.info("calibrate")
        # read calibration
        fluxcalib=read_flux_calibration(args.calib)
        # apply calibration
        apply_flux_calibration(frame, fluxcalib)


    # save output
    write_frame(args.outfile, frame)

    log.info("successfully wrote %s"%args.outfile)
Ejemplo n.º 21
0
def main(args):
    """Run production-level QA tasks selected by the parsed options.

    Args:
        args : parsed options providing specprod_dir, make_frameqa,
            clobber, slurp, remove and channel_hist

    Frame QA is (re)generated when make_frameqa > 0, with plots when
    make_frameqa % 4 >= 2.  QA files are slurped when args.slurp is set.
    When channel_hist is not None a '<prod_name>_chist.pdf' file is
    produced; the default set of histograms is drawn only for
    channel_hist == 'default'.
    """
    log = get_logger()

    log.info("starting")

    qa_prod = QA_Prod(args.specprod_dir)

    # Remake Frame QA?
    remake_frameqa = args.make_frameqa > 0
    if remake_frameqa:
        log.info("(re)generating QA related to frames")
        with_plots = (args.make_frameqa % 4) >= 2
        qa_prod.make_frameqa(make_plots=with_plots, clobber=args.clobber)

    # Slurp?
    if args.slurp:
        qa_prod.slurp(make=remake_frameqa, remove=args.remove)

    # Channel histograms
    if args.channel_hist is None:
        return

    # imports are only needed for the histograms
    from matplotlib.backends.backend_pdf import PdfPages
    from desispec.qa import qa_plots as dqqp

    qa_prod.load_data()
    outfile = qa_prod.prod_name + '_chist.pdf'
    pp = PdfPages(outfile)
    if args.channel_hist == 'default':
        # (QA type, metric, extra keyword arguments) of the default set
        default_hists = (
            ('FIBERFLAT', 'MAX_RMS', {}),
            ('SKYSUB', 'MED_RESID', dict(xlim=(-1, 1))),
            ('FLUXCALIB', 'MAX_ZP_OFF', {}),
        )
        for qatype, metric, extra in default_hists:
            dqqp.prod_channel_hist(qa_prod, qatype, metric,
                                   pp=pp, close=False, **extra)
    # Finish
    print("Writing {:s}".format(outfile))
    pp.close()
Ejemplo n.º 22
0
 def test_log(self):
     """Emit one message of every severity for each requested level."""
     env_level=os.getenv("DESI_LOGLEVEL")
     #- a non-empty DESI_LOGLEVEL is reported alongside the requested level
     env_override=bool(env_level)
     for requested in (None,log.DEBUG,log.INFO,log.WARNING,log.ERROR) :
         logger=log.get_logger(requested)
         print("with the requested debugging level={0}".format(requested))
         if env_override :
             print("(but overuled by env. DESI_LOGLEVEL='{0}')".format(env_level))
         print("--------------------------------------------------")
         logger.debug("This is a debugging message")
         logger.info("This is an information")
         logger.warning("This is an warning")
         logger.error("This is an error")
         logger.critical("This is a critical error")
Ejemplo n.º 23
0
 def test_log(self):
     """Exercise every logging method for a range of requested levels."""
     desi_level = os.getenv("DESI_LOGLEVEL")
     requested_levels = [None, log.DEBUG, log.INFO, log.WARNING, log.ERROR]
     for level in requested_levels:
         logger = log.get_logger(level)
         print("with the requested debugging level={0}".format(level))
         # Mention the environment variable when it is set and non-empty.
         if not (desi_level is None or desi_level == ""):
             note = "(but overuled by env. DESI_LOGLEVEL='{0}')".format(
                 desi_level)
             print(note)
         print("--------------------------------------------------")
         # One message per severity, from least to most severe.
         messages = (
             (logger.debug, "This is a debugging message"),
             (logger.info, "This is an information"),
             (logger.warning, "This is an warning"),
             (logger.error, "This is an error"),
             (logger.critical, "This is a critical error"),
         )
         for emit, text in messages:
             emit(text)
Ejemplo n.º 24
0
def apply_fiberflat(frame, fiberflat):
    """Divide a frame by a fiberflat.  Modifies frame.flux, frame.ivar and frame.mask

    Args:
        frame : `desispec.Frame` object
        fiberflat : `desispec.FiberFlat` object

    The flux is divided by the fiberflat only where the fiberflat is > 0.
    The ivar is propagated through the division and is 0 wherever the
    frame ivar, the fiberflat ivar or the fiberflat is not > 0.

    frame.mask gets bit specmask.BADFIBERFLAT set where
      * fiberflat.fiberflat == 0
      * fiberflat.ivar == 0
      * fiberflat.mask != 0

    Raises:
        ValueError: if frame.wave and fiberflat.wave differ (np.allclose)
    """
    log = get_logger()
    log.info("starting")

    # Both objects must share the same wavelength grid.
    if not np.allclose(frame.wave, fiberflat.wave):
        message = "frame and fiberflat do not have the same wavelength arrays"
        log.critical(message)
        raise ValueError(message)

    # Error propagation for F' = F/C  (F = flux, C = fiberflat):
    #  Var(F') = Var(F)/C**2 + F**2*(  d(1/C)/dC )**2*Var(C)
    #          = 1/(ivar(F)*C**2) + F**2*(1/C**2)**2*Var(C)
    #          = 1/(ivar(F)*C**2) + F**2*Var(C)/C**4
    #          = 1/(ivar(F)*C**2) + F**2/(ivar(C)*C**4)
    flat = fiberflat.fiberflat
    flat_ivar = fiberflat.ivar

    #- ivar comes first because it needs the flux before flat-fielding.
    #- The "+ (x == 0)" terms only keep the denominators non-zero; those
    #- pixels are zeroed by `usable`.
    usable = (frame.ivar > 0) * (flat_ivar > 0) * (flat > 0)
    var_from_frame = 1. / ((frame.ivar + (frame.ivar == 0)) *
                           (flat**2 + (flat == 0)))
    var_from_flat = frame.flux**2 / (flat_ivar * flat**4 +
                                     (flat_ivar * flat == 0))
    frame.ivar = usable / (var_from_frame + var_from_flat)

    #- Flat-field the flux where that does not divide by 0
    positive = np.where(flat > 0)
    frame.flux[positive] = frame.flux[positive] / flat[positive]

    #- Flag pixels with an unusable fiberflat
    unusable_flat = (flat == 0.0) | (flat_ivar == 0) | (fiberflat.mask != 0)
    frame.mask[unusable_flat] |= specmask.BADFIBERFLAT

    log.info("done")
Ejemplo n.º 25
0
 def _redirect_stdout(to):
     """Send stdout and stderr to the file `to` and rebuild the log handlers.

     `fd` and `fde` are not defined in this block.
     NOTE(review): presumably they are the stdout and stderr file
     descriptors captured by the enclosing scope -- confirm.
     """
     for stream_name, descriptor in (('stdout', fd), ('stderr', fde)):
         getattr(sys, stream_name).close() # + implicit flush()
         os.dup2(to.fileno(), descriptor) # descriptor now writes to 'to' file
         setattr(sys, stream_name, os.fdopen(descriptor, 'w')) # Python writes to descriptor
     # update desi logging to use new stdout
     log = get_logger()
     while log.handlers:
         log.removeHandler(log.handlers[0])
     # Add the current stdout.
     stdout_handler = logging.StreamHandler(sys.stdout)
     stdout_handler.setFormatter(
         logging.Formatter('%(levelname)s:%(filename)s:%(lineno)s:%(funcName)s: %(message)s'))
     log.addHandler(stdout_handler)
Ejemplo n.º 26
0
 def _redirect_stdout(to):
     """Redirect stdout and stderr to `to`, then re-attach the log handler.

     `fd` and `fde` are not defined in this block.
     NOTE(review): presumably they are the stdout and stderr file
     descriptors captured by the enclosing scope -- confirm.
     """
     # stdout: close (which flushes), re-point fd at `to`, reopen on fd
     sys.stdout.close()
     os.dup2(to.fileno(), fd)
     sys.stdout = os.fdopen(fd, 'w')
     # stderr: same sequence on fde
     sys.stderr.close()
     os.dup2(to.fileno(), fde)
     sys.stderr = os.fdopen(fde, 'w')
     # The logger still holds handlers on the old stream: drop them all
     log = get_logger()
     for old_handler in list(log.handlers):
         log.removeHandler(old_handler)
     # and log to the new stdout instead
     record_format = '%(levelname)s:%(filename)s:%(lineno)s:%(funcName)s: %(message)s'
     new_handler = logging.StreamHandler(sys.stdout)
     new_handler.setFormatter(logging.Formatter(record_format))
     log.addHandler(new_handler)
Ejemplo n.º 27
0
def apply_fiberflat(frame, fiberflat):
    """Flat-field a frame.  Modifies frame.flux, frame.ivar and frame.mask

    Args:
        frame : `desispec.Frame` object
        fiberflat : `desispec.FiberFlat` object

    The frame flux is divided by the fiberflat wherever the fiberflat is
    > 0 and is left untouched elsewhere.  The ivar is propagated through
    the division and is 0 wherever the frame ivar, the fiberflat ivar or
    the fiberflat is not > 0.

    frame.mask gets bit specmask.BADFIBERFLAT set where
      * fiberflat.fiberflat == 0
      * fiberflat.ivar == 0
      * fiberflat.mask != 0

    Raises:
        ValueError: if frame.wave and fiberflat.wave differ (np.allclose)
    """
    log=get_logger()
    log.info("starting")

    same_wave = np.allclose(frame.wave, fiberflat.wave)
    if not same_wave:
        message = "frame and fiberflat do not have the same wavelength arrays"
        log.critical(message)
        raise ValueError(message)

    #- Error propagation for F' = F/C  (F = flux, C = fiberflat):
    #-  Var(F') = Var(F)/C**2 + F**2*(  d(1/C)/dC )**2*Var(C)
    #-          = 1/(ivar(F)*C**2) + F**2*(1/C**2)**2*Var(C)
    #-          = 1/(ivar(F)*C**2) + F**2*Var(C)/C**4
    #-          = 1/(ivar(F)*C**2) + F**2/(ivar(C)*C**4)

    #- shorthand
    ff = fiberflat
    sp = frame

    #- ivar is computed before the flux is touched, as it uses the original flux ;
    #- the "+(x==0)" terms only protect against divisions by 0
    keep=(sp.ivar>0)*(ff.ivar>0)*(ff.fiberflat>0)
    first_term=1./((sp.ivar+(sp.ivar==0))*(ff.fiberflat**2+(ff.fiberflat==0)))
    second_term=sp.flux**2/(ff.ivar*ff.fiberflat**4+(ff.ivar*ff.fiberflat==0))
    sp.ivar=keep/(first_term+second_term)

    #- flux : divide only where the fiberflat is strictly positive
    nonzero = np.where(ff.fiberflat > 0)
    sp.flux[nonzero] = sp.flux[nonzero] / ff.fiberflat[nonzero]

    #- mask : flag every pixel whose fiberflat is unusable
    zero_flat = (ff.fiberflat == 0.0)
    zero_ivar = (ff.ivar == 0)
    masked = (ff.mask != 0)
    sp.mask[zero_flat | zero_ivar | masked] |= specmask.BADFIBERFLAT

    log.info("done")
Ejemplo n.º 28
0
def sim(night, nspec=5, clobber=False):
    """
    Produce the simulated inputs used by the integration test.

    For exposures 0, 1, 2 (flavors flat, arc, dark) this runs the
    ``newexp-desi`` command followed by ``pixsim-desi --preproc``, both
    through pipe.runcmd.

    Args:
        night (str): YEARMMDD
        nspec (int, optional): number of spectra to include
        clobber (bool, optional): rerun steps even if outputs already exist

    Raises:
        RuntimeError if any script fails
    """
    log = logging.get_logger()

    # One exposure per flavor; the exposure id is the position in the list.
    for expid, flavor in enumerate(['flat', 'arc', 'dark']):
        # Step 1: fibermap + simspec for this exposure
        newexp_cmd = "newexp-desi --flavor {flavor} --nspec {nspec} --night {night} --expid {expid}".format(
            expid=expid, flavor=flavor, nspec=nspec, night=night)
        fibermap = io.findfile('fibermap', night, expid)
        simspec = '{}/simspec-{:08d}.fits'.format(os.path.dirname(fibermap),
                                                  expid)
        status = pipe.runcmd(newexp_cmd, inputs=[],
                             outputs=[fibermap, simspec], clobber=clobber)
        if status != 0:
            raise RuntimeError(
                'pixsim newexp failed for {} exposure {}'.format(
                    flavor, expid))

        # Step 2: pixel-level simulation, one pix file per camera
        pixsim_cmd = "pixsim-desi --preproc --nspec {nspec} --night {night} --expid {expid}".format(
            expid=expid, nspec=nspec, night=night)
        pix_outputs = [fibermap.replace('fibermap-', 'simpix-')]
        pix_outputs += [io.findfile('pix', night, expid, camera)
                        for camera in ['b0', 'r0', 'z0']]
        status = pipe.runcmd(pixsim_cmd, inputs=[fibermap, simspec],
                             outputs=pix_outputs, clobber=clobber)
        if status != 0:
            raise RuntimeError('pixsim failed for {} exposure {}'.format(
                flavor, expid))

    return
Ejemplo n.º 29
0
def main() :
    """Command-line entry point: compute a sky model from one exposure frame.

    Parses --infile, --fibermap, --fiberflat and --outfile (all required),
    flat-fields the frame, computes the sky model from it and writes the
    result to --outfile.

    Exits with status 12 if the fibermap contains no fiber with
    OBJTYPE == "SKY" inside the frame's [SPECMIN, SPECMAX] fiber range.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--infile', type = str, default = None, required=True,
                        help = 'path of DESI exposure frame fits file')
    # help text fixed: this option is the fibermap, not the frame
    parser.add_argument('--fibermap', type = str, default = None, required=True,
                        help = 'path of DESI exposure fibermap fits file')
    parser.add_argument('--fiberflat', type = str, default = None, required=True,
                        help = 'path of DESI fiberflat fits file')
    parser.add_argument('--outfile', type = str, default = None, required=True,
                        help = 'path of DESI sky fits file')

    args = parser.parse_args()
    log=get_logger()

    log.info("starting")

    # read exposure to load data and get range of spectra
    frame = read_frame(args.infile)
    specmin=frame.header["SPECMIN"]
    specmax=frame.header["SPECMAX"]

    # read fibermap to locate sky fibers; only fibers belonging to this
    # frame (SPECMIN <= FIBER <= SPECMAX) are considered
    fibermap = read_fibermap(args.fibermap)
    selection=np.where((fibermap["OBJTYPE"]=="SKY")&(fibermap["FIBER"]>=specmin)&(fibermap["FIBER"]<=specmax))[0]
    if selection.size == 0 :
        log.error("no sky fiber in fibermap %s"%args.fibermap)
        sys.exit(12)

    # read fiberflat
    fiberflat = read_fiberflat(args.fiberflat)

    # apply fiberflat to sky fibers (modifies frame in place)
    apply_fiberflat(frame, fiberflat)

    # compute sky model
    skymodel = compute_sky(frame, fibermap)

    # write result
    write_sky(args.outfile, skymodel, frame.header)

    log.info("successfully wrote %s"%args.outfile)
Ejemplo n.º 30
0
def check_env():
    """
    Check required environment variables; exit the process if any is missing.

    Every problem is reported with a warning first.  If anything is missing,
    a critical message is logged and sys.exit(1) is called (SystemExit).
    Otherwise $DESI_SPECTRO_DATA is overridden to $DESI_SPECTRO_SIM/$PIXPROD.
    """
    log = logging.get_logger()
    missing_env = False

    #- template locations
    basis_templates = os.getenv('DESI_BASIS_TEMPLATES')
    if basis_templates is None:
        #- (previously crashed here with NameError: .format(name) was called
        #-  on a message without placeholders before `name` existed)
        log.warning(
            'missing $DESI_BASIS_TEMPLATES needed for simulating spectra')
        templates_ok = False
    else:
        #- only test the directory when the variable is set;
        #- os.path.isdir(None) raises TypeError
        templates_ok = os.path.isdir(basis_templates)
        if not templates_ok:
            log.warning('missing $DESI_BASIS_TEMPLATES directory')

    if not templates_ok:
        log.warning(
            'e.g. see NERSC:/project/projectdirs/desi/spectro/templates/basis_templates/v2.2'
        )
        missing_env = True

    for name in ('DESI_SPECTRO_SIM', 'DESI_SPECTRO_REDUX', 'PIXPROD',
                 'SPECPROD', 'DESIMODEL'):
        if name not in os.environ:
            log.warning("missing ${0}".format(name))
            missing_env = True

    #- Wait until the end to fail so that we report everything that
    #- is missing before actually exiting
    if missing_env:
        log.warning("Why are these needed?")
        log.warning("    Simulations written to $DESI_SPECTRO_SIM/$PIXPROD/")
        log.warning("    Raw data read from $DESI_SPECTRO_DATA/")
        log.warning(
            "    Spectro pipeline output written to $DESI_SPECTRO_REDUX/$SPECPROD/"
        )
        log.warning("    Templates are read from $DESI_BASIS_TEMPLATES")
        log.critical("missing env vars; exiting without running pipeline")
        sys.exit(1)

    #- Override $DESI_SPECTRO_DATA to match $DESI_SPECTRO_SIM/$PIXPROD
    os.environ['DESI_SPECTRO_DATA'] = os.path.join(
        os.getenv('DESI_SPECTRO_SIM'), os.getenv('PIXPROD'))
Ejemplo n.º 31
0
def retry_task(failpath, newopts=None):
    """Re-run a task described by a failure yaml file.

    Args:
        failpath (str): path of the failure yaml file.  It must provide the
            keys 'step', 'rawdir', 'proddir', 'task', 'graph', 'opts'
            and 'procs'.
        newopts (optional): if not None, passed to run_task in place of the
            'opts' recorded in the yaml file.

    Raises:
        RuntimeError: if failpath is not an existing file.
        Any exception raised by run_task is logged and re-raised.

    When the recorded number of processes is > 1 the task is run with
    MPI.COMM_WORLD.  If run_task succeeds, rank 0 deletes failpath.
    """
    log = get_logger()

    if not os.path.isfile(failpath):
        raise RuntimeError("failure yaml file {} does not exist".format(failpath))

    # NOTE(review): yaml.load without an explicit Loader can construct
    # arbitrary Python objects and is rejected by recent PyYAML releases.
    # Left unchanged because the file may rely on non-"safe" tags; consider
    # yaml.safe_load (or an explicit Loader=) once that is confirmed.
    with open(failpath, 'r') as f:
        fyml = yaml.load(f)

    step = fyml['step']
    rawdir = fyml['rawdir']
    proddir = fyml['proddir']
    name = fyml['task']  # not used below; the lookup still requires the key
    grph = fyml['graph']
    origopts = fyml['opts']
    nproc = fyml['procs']

    comm = None
    rank = 0

    if nproc > 1:
        # MPI is only imported when the original task was parallel
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        nworld = comm.size
        rank = comm.rank
        if nworld != nproc and rank == 0:
            log.warning("WARNING: original task was run with {} processes, re-running with {} instead".format(nproc, nworld))

    opts = origopts
    if newopts is not None:
        log.warning("WARNING: overriding original options")
        opts = newopts

    try:
        run_task(step, rawdir, proddir, grph, opts, comm=comm)
    except BaseException:
        # Log and re-raise everything (same reach as a bare except)
        log.error("Retry Failed")
        raise
    else:
        # Success: the failure record is no longer needed
        if rank == 0:
            os.remove(failpath)
    return
Ejemplo n.º 32
0
def decorrelate(Cinv):
    """Diagonalize an inverse covariance with its matrix square root.

    This is the decorrelation step of the spectroperfectionism algorithm of
    Bolton & Schlegel 2009 (BS) http://arxiv.org/abs/0911.2689, which uses the
    matrix square root of Cinv to form a diagonal basis. Compared with the
    eigenvector or Cholesky bases this generally yields more localized basis
    vectors, see Hamilton & Tegmark 2000 http://arxiv.org/abs/astro-ph/9905192.

    Args:
        Cinv(numpy.ndarray): Square array of inverse covariance matrix elements.
            A scipy.sparse matrix is accepted too, but it is converted to a
            dense matrix internally, so there is no performance advantage.

    Returns:
        tuple: (ivar, R) where ivar holds the uncorrelated flux inverse
            variances, shape (nflux,), and R is the resolution matrix, shape
            (nflux,nflux), returned as a dense numpy array. Each row of R is
            the corresponding row of the square root Q divided by its sum.
    """
    log = get_logger()

    # Symmetrize first to remove roundoff asymmetries, then densify if needed.
    Cinv = 0.5*(Cinv + Cinv.T)
    if scipy.sparse.issparse(Cinv):
        Cinv = Cinv.todense()

    # Eigen-decomposition of the symmetric matrix; this route is used rather
    # than scipy.linalg.sqrtm because it is about 2x faster for a positive
    # definite matrix.
    eigvals, eigvecs = scipy.linalg.eigh(Cinv)

    # Negative eigenvalues are clipped to zero before taking the square root.
    negative = eigvals < 0
    nbad = np.count_nonzero(negative)
    if nbad > 0:
        log.warning('zeroing {0:d} negative eigenvalue(s).'.format(nbad))
        eigvals[negative] = 0.

    # Matrix square root Q such that Cinv = Q.Q
    sqrt_diag = np.diag(np.sqrt(eigvals))
    Q = eigvecs.dot(sqrt_diag.dot(eigvecs.T))

    # Row sums of Q give the normalization of R and the diagonal errors.
    row_sum = np.sum(Q,axis=1)
    R = Q/row_sum[:,np.newaxis]
    ivar = row_sum**2
    return ivar,R
Ejemplo n.º 33
0
def decorrelate(Cinv):
    """Turn an inverse covariance into diagonal ivar plus a resolution matrix.

    Decorrelation step of the spectroperfectionism algorithm (Bolton &
    Schlegel 2009 (BS) http://arxiv.org/abs/0911.2689), based on the matrix
    square root of Cinv. This basis is generally preferred over the
    eigenvector or Cholesky bases because its vectors are more localized
    (Hamilton & Tegmark 2000 http://arxiv.org/abs/astro-ph/9905192).

    Args:
        Cinv(numpy.ndarray): Square array of inverse covariance matrix elements,
            either dense or in a scipy.sparse format. Sparse input is converted
            to a dense matrix internally, so it brings no performance gain.

    Returns:
        tuple: (ivar, R). ivar has shape (nflux,) and contains the squared row
            sums of the matrix square root; R has shape (nflux,nflux), is a
            dense numpy array, and contains the rows of the square root
            normalized by their sums.
    """
    log = get_logger()

    # Force exact symmetry (cleans up roundoff), then make the matrix dense.
    symmetric = 0.5 * (Cinv + Cinv.T)
    if scipy.sparse.issparse(symmetric):
        symmetric = symmetric.todense()

    # eigh is used instead of scipy.linalg.sqrtm: about 2x faster for a
    # positive definite matrix.
    values, vectors = scipy.linalg.eigh(symmetric)

    # Clip negative eigenvalues so the square root below stays real.
    is_negative = values < 0
    nbad = np.count_nonzero(is_negative)
    if nbad > 0:
        log.warning('zeroing {0:d} negative eigenvalue(s).'.format(nbad))
        values[is_negative] = 0.

    # Q is the matrix square root: Cinv = Q.Q
    root_diagonal = np.diag(np.sqrt(values))
    Q = vectors.dot(root_diagonal.dot(vectors.T))

    # Normalize each row of Q by its sum; the squared sums are the ivar.
    norm = np.sum(Q, axis=1)
    resolution = Q / norm[:, np.newaxis]
    return norm**2, resolution
Ejemplo n.º 34
0
def subtract_sky(frame, skymodel) :
    """Remove the sky model from a frame in place.

    frame.flux is reduced by skymodel.flux, frame.ivar is replaced by the
    combination of both inverse variances (util.combine_ivar) and
    frame.mask is OR-ed with skymodel.mask.

    Raises:
        AssertionError: if nspec or nwave differ between the two inputs
        ValueError: if the two wavelength arrays are not (almost) equal
    """
    # Dimensions have to agree before anything is modified.
    assert frame.nspec == skymodel.nspec
    assert frame.nwave == skymodel.nwave

    log = get_logger()
    log.info("starting")

    # Refuse to subtract a model defined on another wavelength grid.
    same_grid = np.allclose(frame.wave, skymodel.wave)
    if not same_grid:
        message = "frame and sky not on same wavelength grid"
        log.error(message)
        raise ValueError(message)

    # In-place updates of the three frame arrays.
    frame.flux -= skymodel.flux
    frame.ivar = util.combine_ivar(frame.ivar, skymodel.ivar)
    frame.mask |= skymodel.mask

    log.info("done")
Ejemplo n.º 35
0
def subtract_sky(frame, skymodel):
    """Subtract a sky model from a frame, modifying the frame in place.

    Updates frame.flux (minus skymodel.flux), frame.ivar (combined with
    skymodel.ivar through util.combine_ivar) and frame.mask (bitwise OR
    with skymodel.mask).

    Raises:
        AssertionError: when the number of spectra or wavelengths differ
        ValueError: when frame.wave and skymodel.wave do not match
    """
    assert frame.nspec == skymodel.nspec
    assert frame.nwave == skymodel.nwave

    logger = get_logger()
    logger.info("starting")

    # Frame and model must share one wavelength grid; abort otherwise.
    grids_match = np.allclose(frame.wave, skymodel.wave)
    if not grids_match:
        error_text = "frame and sky not on same wavelength grid"
        logger.error(error_text)
        raise ValueError(error_text)

    frame.flux -= skymodel.flux
    frame.ivar = util.combine_ivar(frame.ivar, skymodel.ivar)
    frame.mask |= skymodel.mask

    logger.info("done")
Ejemplo n.º 36
0
def main() :
    """Command-line entry point: compute a fiberflat from one frame.

    Reads the frame given by --infile, computes the fiberflat from it and
    writes the result, with the frame header, to --outfile.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Both options are required string paths; only flag and help text differ.
    required_paths = (
        ('--infile', 'path of DESI frame fits file corresponding to a continuum lamp exposure'),
        ('--outfile', 'path of DESI fiberflat fits file'),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, type=str, default=None, required=True,
                            help=description)

    args = parser.parse_args()

    log = get_logger()
    log.info("starting")

    frame = read_frame(args.infile)
    write_fiberflat(args.outfile, compute_fiberflat(frame), frame.header)

    log.info("successfully wrote %s"%args.outfile)
Ejemplo n.º 37
0
def main(args) :
    """Run production-level QA steps selected by the parsed arguments.

    Depending on args this (re)generates frame QA, slurps the QA files
    and/or writes a PDF of channel histograms named
    <prod_name>_chist.pdf.
    """
    log = get_logger()
    log.info("starting")

    qa_prod = QA_Prod(args.specprod_dir)
    remake_frameqa = args.make_frameqa > 0

    # Remake Frame QA?
    if remake_frameqa:
        log.info("(re)generating QA related to frames")
        # plots are requested when make_frameqa modulo 4 is 2 or 3
        with_plots = (args.make_frameqa % 4) >= 2
        qa_prod.make_frameqa(make_plots=with_plots, clobber=args.clobber)

    # Slurp?
    if args.slurp:
        qa_prod.slurp(make=(args.make_frameqa > 0), remove=args.remove)

    # Nothing more to do unless channel histograms were requested
    if args.channel_hist is None:
        return

    # Plotting modules are imported only when actually needed
    from matplotlib.backends.backend_pdf import PdfPages
    from desispec.qa import qa_plots as dqqp

    qa_prod.load_data()
    outfile = qa_prod.prod_name+'_chist.pdf'
    pp = PdfPages(outfile)
    if args.channel_hist == 'default':
        # default set of histograms
        dqqp.prod_channel_hist(qa_prod, 'FIBERFLAT', 'MAX_RMS', pp=pp, close=False)
        dqqp.prod_channel_hist(qa_prod, 'SKYSUB', 'MED_RESID', xlim=(-1,1), pp=pp, close=False)
        dqqp.prod_channel_hist(qa_prod, 'FLUXCALIB', 'MAX_ZP_OFF', pp=pp, close=False)
    # Finish
    print("Writing {:s}".format(outfile))
    pp.close()
Ejemplo n.º 38
0
def spline_fit(output_wave,input_wave,input_flux,required_resolution,input_ivar=None,order=3):
    """Performs a least-squares spline fit and evaluates it on output_wave.

    Interior knots are spaced by (approximately) required_resolution between
    the first and last usable input wavelengths.

    Args:
        output_wave : wavelengths at which the fitted spline is evaluated
        input_wave : wavelengths of the input data
        input_flux : input flux values, same length as input_wave
        required_resolution : requested knot spacing, in units of input_wave
        input_ivar : [optional] weights passed to splrep; when given, the
            knot range is restricted to the samples with ivar>0
        order : [optional] spline degree (default 3)

    Returns:
        output_flux : spline values at output_wave

    Raises:
        ValueError: if input_ivar is given and fewer than 2 values are > 0
    """
    if input_ivar is not None :
        selection=np.where(input_ivar>0)[0]
        if selection.size < 2 :
            # was `raise Error`, a name defined nowhere in view, which would
            # have surfaced as a NameError without any useful message
            message="cannot do spline fit because only {0:d} values with ivar>0".format(selection.size)
            log=get_logger()
            log.error(message)
            raise ValueError(message)
        w1=input_wave[selection[0]]
        w2=input_wave[selection[-1]]
    else :
        w1=input_wave[0]
        w2=input_wave[-1]

    # n interior knots, evenly spaced and offset by half a step from w1,
    # so that no knot sits on the boundary of the fitted range
    n=int((w2-w1)/required_resolution)
    res=(w2-w1)/(n+1)
    knots=w1+res*(0.5+np.arange(n))

    # task=-1 : weighted least-squares spline on the given knots.
    # NOTE(review): the scipy documentation recommends w = 1/sigma, whereas
    # ivar = 1/sigma**2 is passed here; left unchanged so results stay
    # identical, but worth confirming.
    toto=scipy.interpolate.splrep(input_wave,input_flux,w=input_ivar,k=order,task=-1,t=knots)
    output_flux = scipy.interpolate.splev(output_wave,toto)
    return output_flux
Ejemplo n.º 39
0
def main(args):
    """Compute a sky model for one frame, optionally with QA, and write it.

    The frame read from args.infile is flat-fielded with args.fiberflat,
    the sky model is computed from it and written to args.outfile.  If
    args.qafile or args.qafig is set, the SKYSUB QA is run as well.
    """
    log = get_logger()
    log.info("starting")

    # read exposure to load data and get range of spectra
    frame = read_frame(args.infile)
    # (fiber range is computed but not used further in this function)
    specmin = np.min(frame.fibers)
    specmax = np.max(frame.fibers)

    # flat-field the frame in place
    fiberflat = read_fiberflat(args.fiberflat)
    apply_fiberflat(frame, fiberflat)

    # compute sky model
    skymodel = compute_sky(frame)

    # QA is requested by either a QA file or a QA figure
    qa_requested = not (args.qafile is None and args.qafig is None)
    if qa_requested:
        log.info("performing skysub QA")
        qaframe = load_qa_frame(args.qafile,
                                frame,
                                flavor=frame.meta['FLAVOR'])
        qaframe.run_qa('SKYSUB', (frame, skymodel))
        # write the QA file only if a path was given
        if args.qafile is not None:
            write_qa_frame(args.qafile, qaframe)
            log.info("successfully wrote {:s}".format(args.qafile))
        # same for the figure
        if args.qafig is not None:
            qa_plots.frame_skyres(args.qafig, frame, skymodel, qaframe)

    # write result
    write_sky(args.outfile, skymodel, frame.meta)
    log.info("successfully wrote %s" % args.outfile)
Ejemplo n.º 40
0
def sim(night, nspec=5, clobber=False):
    """
    Create the simulated data needed by the integration test.

    Three exposures are produced (expid 0, 1, 2 with flavors flat, arc,
    dark).  For each one ``newexp-desi`` is run first, then
    ``pixsim-desi --preproc``; both go through pipe.runcmd.

    Args:
        night (str): YEARMMDD
        nspec (int, optional): number of spectra to include
        clobber (bool, optional): rerun steps even if outputs already exist

    Raises:
        RuntimeError if any script fails
    """
    log = logging.get_logger()

    cameras = ['b0', 'r0', 'z0']
    exposures = [(0, 'flat'), (1, 'arc'), (2, 'dark')]

    for expid, flavor in exposures:
        # --- fibermap and simspec ---
        cmd = "newexp-desi --flavor {flavor} --nspec {nspec} --night {night} --expid {expid}".format(
            expid=expid, flavor=flavor, nspec=nspec, night=night)
        fibermap = io.findfile('fibermap', night, expid)
        simspec = '{}/simspec-{:08d}.fits'.format(os.path.dirname(fibermap), expid)
        newexp_status = pipe.runcmd(cmd, inputs=[], outputs=[fibermap, simspec], clobber=clobber)
        if newexp_status != 0:
            raise RuntimeError('pixsim newexp failed for {} exposure {}'.format(flavor, expid))

        # --- pixel-level simulation: simpix file plus one pix file per camera ---
        cmd = "pixsim-desi --preproc --nspec {nspec} --night {night} --expid {expid}".format(expid=expid, nspec=nspec, night=night)
        expected = [fibermap.replace('fibermap-', 'simpix-')]
        expected.extend(io.findfile('pix', night, expid, camera) for camera in cameras)
        pixsim_status = pipe.runcmd(cmd, inputs=[fibermap, simspec], outputs=expected, clobber=clobber)
        if pixsim_status != 0:
            raise RuntimeError('pixsim failed for {} exposure {}'.format(flavor, expid))

    return
Ejemplo n.º 41
0
def check_env():
    """
    Check required environment variables; exit the process if any is missing.

    Each problem is reported with a warning.  If anything is missing a
    critical message is logged and sys.exit(1) is called (SystemExit).
    On success $DESI_SPECTRO_DATA is set to $DESI_SPECTRO_SIM/$PIXPROD.
    """
    log = get_logger()
    missing_env = False

    #- template locations
    basis_templates = os.getenv('DESI_BASIS_TEMPLATES')
    if basis_templates is None:
        log.warning('missing $DESI_BASIS_TEMPLATES needed for simulating spectra')
        templates_ok = False
    else:
        #- only test the directory when the variable is set:
        #- os.path.isdir(None) raises TypeError instead of returning False
        templates_ok = os.path.isdir(basis_templates)
        if not templates_ok:
            log.warning('missing $DESI_BASIS_TEMPLATES directory')

    if not templates_ok:
        log.warning('e.g. see NERSC:/project/projectdirs/desi/spectro/templates/basis_templates/v1.0')
        missing_env = True

    for name in (
        'DESI_SPECTRO_SIM', 'DESI_SPECTRO_REDUX', 'PIXPROD', 'SPECPROD'):
        if name not in os.environ:
            log.warning("missing ${0}".format(name))
            missing_env = True

    #- Wait until the end to fail so that we report everything that
    #- is missing before actually exiting
    if missing_env:
        log.warning("Why are these needed?")
        log.warning("    Simulations written to $DESI_SPECTRO_SIM/$PIXPROD/")
        log.warning("    Raw data read from $DESI_SPECTRO_DATA/")
        log.warning("    Spectro pipeline output written to $DESI_SPECTRO_REDUX/$SPECPROD/")
        log.warning("    Templates are read from $DESI_BASIS_TEMPLATES")
        log.critical("missing env vars; exiting without running pipeline")
        sys.exit(1)

    #- Override $DESI_SPECTRO_DATA to match $DESI_SPECTRO_SIM/$PIXPROD
    os.environ['DESI_SPECTRO_DATA'] = os.path.join(os.getenv('DESI_SPECTRO_SIM'), os.getenv('PIXPROD'))
Ejemplo n.º 42
0
def main(args) :
    """Build and write the sky model of one frame; run SKYSUB QA on request.

    Steps: read the frame (args.infile), apply the fiberflat
    (args.fiberflat) in place, compute the sky model, optionally run the
    QA (when args.qafile or args.qafig is given), write the model to
    args.outfile.
    """
    log = get_logger()
    log.info("starting")

    # read exposure to load data and get range of spectra
    frame = read_frame(args.infile)
    # (the fiber range is not used afterwards in this function)
    specmin, specmax = np.min(frame.fibers), np.max(frame.fibers)

    # read the fiberflat and apply it to the frame
    apply_fiberflat(frame, read_fiberflat(args.fiberflat))

    # compute sky model
    skymodel = compute_sky(frame)

    # QA
    have_qafile = args.qafile is not None
    have_qafig = args.qafig is not None
    if have_qafile or have_qafig:
        log.info("performing skysub QA")
        # load, then run
        qaframe = load_qa_frame(args.qafile, frame, flavor=frame.meta['FLAVOR'])
        qaframe.run_qa('SKYSUB', (frame, skymodel))
        if have_qafile:
            # write
            write_qa_frame(args.qafile, qaframe)
            log.info("successfully wrote {:s}".format(args.qafile))
        if have_qafig:
            # figure(s)
            qa_plots.frame_skyres(args.qafig, frame, skymodel, qaframe)

    # write result
    write_sky(args.outfile, skymodel, frame.meta)
    log.info("successfully wrote %s"%args.outfile)
Ejemplo n.º 43
0
def main(args):
    """Apply the requested calibration steps to a frame and write it.

    Each step runs only if its argument was given, in this order:
    fiberflat (args.fiberflat), sky subtraction (args.sky) and flux
    calibration (args.calib).  The frame read from args.infile is
    modified in place and written to args.outfile.

    Exits with status 12 when none of the three steps is requested.
    """
    log = get_logger()

    if (args.fiberflat is None) and (args.sky is None) and (args.calib is
                                                            None):
        log.critical('no --fiberflat, --sky, or --calib; nothing to do ?!?')
        sys.exit(12)

    frame = read_frame(args.infile)

    # identity tests (`is not None`) replace the former `!= None` comparisons
    if args.fiberflat is not None:
        log.info("apply fiberflat")
        # read fiberflat
        fiberflat = read_fiberflat(args.fiberflat)

        # apply fiberflat to sky fibers
        apply_fiberflat(frame, fiberflat)

    if args.sky is not None:
        log.info("subtract sky")
        # read sky
        skymodel = read_sky(args.sky)
        # subtract sky
        subtract_sky(frame, skymodel)

    if args.calib is not None:
        log.info("calibrate")
        # read calibration
        fluxcalib = read_flux_calibration(args.calib)
        # apply calibration
        apply_flux_calibration(frame, fluxcalib)

    # save output
    write_frame(args.outfile, frame, units='1e-17 erg/(s cm2 A)')

    log.info("successfully wrote %s" % args.outfile)
Ejemplo n.º 44
0
def _test_zbest_io():
    """Round-trip test of write_zbest / read_zbest.

    Writes a ZfindBase built from random data to 'zbest_test.fits' in the
    current directory, reads it back and compares the fields, first without
    and then with zspec=True.  The file is removed afterwards, also when a
    check fails.

    This should be moved to a separate test file?  Yes, it should.
    """
    import os
    log = get_logger()
    nspec, nflux = 10, 20
    wave = np.arange(nflux)
    flux = np.random.uniform(size=(nspec, nflux))
    ivar = np.random.uniform(size=(nspec, nflux))
    zfind1 = ZfindBase(wave, flux, ivar)

    brickname = '1234p567'
    targetids = np.random.randint(0, 12345678, size=nspec)

    outfile = 'zbest_test.fits'
    try:
        # round trip of the redshift results
        write_zbest(outfile, brickname, targetids, zfind1)
        zfind2 = read_zbest(outfile)

        assert np.all(zfind2.z == zfind1.z)
        assert np.all(zfind2.zerr == zfind1.zerr)
        assert np.all(zfind2.zwarn == zfind1.zwarn)
        assert np.all(zfind2.type == zfind1.type)
        assert np.all(zfind2.subtype == zfind1.subtype)
        assert np.all(zfind2.brickname == brickname)
        assert np.all(zfind2.targetid == targetids)

        # round trip again, this time including the spectra
        write_zbest(outfile, brickname, targetids, zfind1, zspec=True)
        zfind3 = read_zbest(outfile)
        assert np.all(zfind3.wave == zfind1.wave)
        assert np.all(zfind3.flux == zfind1.flux)
        assert np.all(zfind3.ivar == zfind1.ivar)
        assert np.all(zfind3.model == zfind1.model)

        log.info("looks OK to me")
    finally:
        # do not leave the scratch file behind when a check fails
        if os.path.exists(outfile):
            os.remove(outfile)
Ejemplo n.º 45
0
def main(args) :
    """Run cosmic ray rejection on an image and write the result.

    The image is read from args.infile and written, with its updated mask,
    to args.outfile, or back to args.infile when no outfile is given.
    With args.ignore_cosmic_ccdmask the ccdmask.COSMIC bit is cleared from
    the mask before the rejection is run.
    """
    # default: overwrite the input file
    outfile = args.infile if args.outfile is None else args.outfile

    log = get_logger()
    log.info("starting finding cosmics in %s"%args.infile)

    img = image.read_image(args.infile)

    if args.ignore_cosmic_ccdmask :
        log.warning("ignore cosmic ccdmask for test")
        log.debug("ccdmask.COSMIC = %d"%ccdmask.COSMIC)
        # previous cosmic flags are extracted here but not used afterwards
        preexisting_cosmics = img.mask & ccdmask.COSMIC
        #- turn off cosmic mask
        img._mask &= ~ccdmask.COSMIC

    reject_cosmic_rays(img)

    log.info("writing data and new mask in %s"%outfile)
    image.write_image(outfile, img, meta=img.meta)

    log.info("done")
Ejemplo n.º 46
0
def _test_zbest_io():
    """Check that write_zbest followed by read_zbest preserves the data.

    A ZfindBase filled with random values is written to 'zbest_test.fits'
    (current directory) and read back twice: once with the default options
    and once with zspec=True.  The scratch file is always deleted at the
    end, including when one of the checks fails.

    This should be moved to a separate test file?  Yes, it should.
    """
    import os
    log=get_logger()
    nspec, nflux = 10, 20
    wave = np.arange(nflux)
    flux = np.random.uniform(size=(nspec, nflux))
    ivar = np.random.uniform(size=(nspec, nflux))
    zfind1 = ZfindBase(wave, flux, ivar)

    brickname = '1234p567'
    targetids = np.random.randint(0,12345678, size=nspec)

    outfile = 'zbest_test.fits'
    try:
        # first pass: redshift quantities only
        write_zbest(outfile, brickname, targetids, zfind1)
        zfind2 = read_zbest(outfile)

        assert np.all(zfind2.z == zfind1.z)
        assert np.all(zfind2.zerr == zfind1.zerr)
        assert np.all(zfind2.zwarn == zfind1.zwarn)
        assert np.all(zfind2.type == zfind1.type)
        assert np.all(zfind2.subtype == zfind1.subtype)
        assert np.all(zfind2.brickname == brickname)
        assert np.all(zfind2.targetid == targetids)

        # second pass: spectra included
        write_zbest(outfile, brickname, targetids, zfind1, zspec=True)
        zfind3 = read_zbest(outfile)
        assert np.all(zfind3.wave == zfind1.wave)
        assert np.all(zfind3.flux == zfind1.flux)
        assert np.all(zfind3.ivar == zfind1.ivar)
        assert np.all(zfind3.model == zfind1.model)

        log.info("looks OK to me")
    finally:
        # a failed check used to leave the file on disk
        if os.path.exists(outfile):
            os.remove(outfile)
Ejemplo n.º 47
0
def main(args):
    """Find cosmic rays in an image and save the image with its new mask.

    Input is args.infile; output is args.outfile when given, otherwise the
    input file is overwritten.  If args.ignore_cosmic_ccdmask is set, the
    ccdmask.COSMIC bit is removed from the mask before running
    reject_cosmic_rays.
    """
    outfile = args.outfile
    if outfile is None:
        # no explicit output: write back to the input file
        outfile = args.infile

    log = get_logger()
    log.info("starting finding cosmics in %s" % args.infile)

    img = image.read_image(args.infile)

    if args.ignore_cosmic_ccdmask:
        log.warning("ignore cosmic ccdmask for test")
        log.debug("ccdmask.COSMIC = %d" % ccdmask.COSMIC)
        # earlier cosmic flags; kept in a local that is not used further
        earlier_cosmics = img.mask & ccdmask.COSMIC
        #- turn off cosmic mask
        img._mask &= ~ccdmask.COSMIC

    reject_cosmic_rays(img)

    log.info("writing data and new mask in %s" % outfile)
    image.write_image(outfile, img, meta=img.meta)

    log.info("done")
Ejemplo n.º 48
0
def qa_skysub(param, frame, skymodel, quick_look=False):
    """Calculate QA on SkySubtraction

    Note: Pixels rejected in generating the SkyModel (as above), are
    not rejected in the stats calculated here.  Would need to carry
    along current_ivar to do so.

    Args:
        param : dict of QA parameters; 'PCHI_RESID' and 'PER_RESID' are read
        frame : desispec.Frame object (not modified)
        skymodel : desispec.SkyModel object
        quick_look : bool, optional
          If True, do QuickLook specific QA (or avoid some)
    Returns:
        qadict: dict of QA outputs
          Need to record simple Python objects for yaml (str, float, int)
          With quick_look=True the entries MEAN_CONTIN, MED_SNR and TOT_SNR
          are numpy arrays.
    """
    import copy

    log = get_logger()

    # Output dict
    qadict = {}
    qadict['NREJ'] = int(skymodel.nrej)

    # Grab sky fibers on this frame
    skyfibers = np.where(frame.fibermap['OBJTYPE'] == 'SKY')[0]
    assert np.max(skyfibers) < 500  #- indices, not fiber numbers
    nfibers = len(skyfibers)
    qadict['NSKY_FIB'] = int(nfibers)

    current_ivar = frame.ivar[skyfibers].copy()
    flux = frame.flux[skyfibers]

    # Subtract
    res = flux - skymodel.flux[skyfibers]  # Residuals
    res_ivar = util.combine_ivar(current_ivar, skymodel.ivar[skyfibers])

    # Chi^2 and Probability, one value per sky fiber.
    # scipy.stats.chisqprob no longer exists in SciPy; the chi2 survival
    # function is its documented replacement.
    chi2_fiber = np.sum(res_ivar * (res**2), 1)
    dof = np.sum(res_ivar > 0., axis=1)
    chi2_prob = scipy.stats.chi2.sf(chi2_fiber, dof)
    # Bad models
    qadict['NBAD_PCHI'] = int(np.sum(chi2_prob < param['PCHI_RESID']))
    if qadict['NBAD_PCHI'] > 0:
        log.warning("Bad Sky Subtraction in {:d} fibers".format(
            qadict['NBAD_PCHI']))

    # Median residual
    qadict['MED_RESID'] = float(np.median(res))  # Median residual (counts)
    log.info("Median residual for sky fibers = {:g}".format(
        qadict['MED_RESID']))

    # Residual percentiles
    perc = dustat.perc(res, per=param['PER_RESID'])
    qadict['RESID_PER'] = [float(iperc) for iperc in perc]

    if quick_look:
        # Mean Sky Continuum from all skyfibers
        # need to limit in wavelength?
        # NOTE(review): a scalar size of 200 is applied along both axes of
        # the 2D flux (fibers and wavelength) -- confirm this is intended.
        continuum = scipy.ndimage.median_filter(
            flux, 200)  # taking 200 bins (somewhat arbitrarily)
        qadict['MEAN_CONTIN'] = np.mean(continuum, axis=0)

        # Median Signal to Noise on sky subtracted spectra
        # first do the subtraction, on a real copy: subtract_sky works in
        # place, and the former "copy" was only an alias, so the caller's
        # frame ended up sky-subtracted. The sky model is passed as is.
        fframe = copy.deepcopy(frame)
        subtract_sky(fframe, skymodel)
        nspec = fframe.flux.shape[0]
        medsnr = np.zeros(nspec)
        totsnr = np.zeros(nspec)
        for ii in range(nspec):
            signalmask = fframe.flux[ii, :] > 0
            # total snr considering bin by bin uncorrelated S/N
            snr = fframe.flux[ii, signalmask] * np.sqrt(
                fframe.ivar[ii, signalmask])
            medsnr[ii] = np.median(snr)
            totsnr[ii] = np.sqrt(np.sum(snr**2))
        qadict['MED_SNR'] = medsnr  # for each fiber
        qadict['TOT_SNR'] = totsnr  # for each fiber

    # Return
    return qadict
Ejemplo n.º 49
0
def compute_fiberflat(frame, nsig_clipping=4.) :
    """Compute fiber flat by deriving an average spectrum and dividing all fiber data by this average.

    Input data are expected to be on the same wavelength grid, with uncorrelated noise.
    They however do not have exactly the same resolution.

    args:
        frame (desispec.Frame): input Frame object; the attributes read here are
            nwave, nspec, wave, flux, ivar and R (one resolution matrix per fiber)
        nsig_clipping : [optional] sigma clipping value for outlier rejection

    returns:
        FiberFlat object built as FiberFlat(wave, fiberflat, ivar, mask, meanspec) with
        wave : copy of frame.wave
        fiberflat : array with the shape of frame.flux (data have to be divided by this to be flatfielded)
        ivar : inverse variance of that fiberflat
        mask : 0=ok >0 if problems
        meanspec : deconvolved mean spectrum

    - we first iteratively (at most 20 iterations, stopping when nothing is rejected) :
       - compute a deconvolved mean spectrum
       - compute a fiber flat using the resolution convolved mean spectrum for each fiber
       - smooth the fiber flat along wavelength
       - clip outliers

    - then we compute a fiberflat at the native fiber resolution (not smoothed)

    - the fiberflat is the ratio data/mean , so the data should be divided by this flat

    NOTE THAT THIS CODE HAS NOT BEEN TESTED WITH ACTUAL FIBER TRANSMISSION VARIATIONS,
    OUTLIER PIXELS, DEAD COLUMNS ...
    """
    log=get_logger()
    log.info("starting")

    #
    # chi2 = sum_(fiber f) sum_(wavelength i) w_fi ( D_fi - F_fi (R_f M)_i )^2
    #
    # where
    # w = inverse variance
    # D = flux data (at the resolution of the fiber)
    # F = smooth fiber flat
    # R = resolution data
    # M = mean deconvolved spectrum
    #
    # M = A^{-1} B
    # with
    # A_kl = sum_(fiber f) sum_(wavelength i) w_fi F_fi^2 (R_fki R_fli)
    # B_k = sum_(fiber f) sum_(wavelength i) w_fi D_fi F_fi R_fki
    #
    # defining R'_fi = sqrt(w_fi) F_fi R_fi
    # and      D'_fi = sqrt(w_fi) D_fi
    #
    # A = sum_(fiber f) R'_f R'_f^T
    # B = sum_(fiber f) R'_f D'_f
    # (it's faster that way, and we try to use sparse matrices as much as possible)
    #

    #- Shortcuts
    nwave=frame.nwave
    nfibers=frame.nspec
    wave = frame.wave.copy()  #- this will become part of output too
    flux = frame.flux
    ivar = frame.ivar


    # iterative fitting and clipping to get precise mean spectrum
    # (clipping is applied to this copy, frame.ivar itself is left untouched)
    current_ivar=ivar.copy()


    smooth_fiberflat=np.ones((frame.flux.shape))
    chi2=np.zeros((flux.shape))


    sqrtwflat=np.sqrt(current_ivar)*smooth_fiberflat
    sqrtwflux=np.sqrt(current_ivar)*flux

    # resolution (in A) passed to spline_fit; defined before the loop because
    # it is also needed by the final, non-iterative pass below
    smoothing_res=100. #A

    # test
    #nfibers=20
    nout_tot=0
    for iteration in range(20) :

        # fit mean spectrum
        A=scipy.sparse.lil_matrix((nwave,nwave)).tocsr()
        B=np.zeros((nwave))

        # diagonal sparse matrix with content = sqrt(ivar)*flat of a given fiber
        SD=scipy.sparse.lil_matrix((nwave,nwave))

        # loop on fiber to handle resolution
        for fiber in range(nfibers) :
            if fiber%10==0 :
                log.info("iter %d fiber %d"%(iteration,fiber))

            ### R = Resolution(resolution_data[fiber])
            R = frame.R[fiber]

            # diagonal sparse matrix with content = sqrt(ivar)*flat
            SD.setdiag(sqrtwflat[fiber])

            sqrtwflatR = SD*R # each row r of R is multiplied by sqrtwflat[r]

            A = A+(sqrtwflatR.T*sqrtwflatR).tocsr()
            B += sqrtwflatR.T*sqrtwflux[fiber]

        log.info("iter %d solving"%iteration)

        mean_spectrum=cholesky_solve(A.todense(),B)

        log.info("iter %d smoothing"%iteration)

        # fit smooth fiberflat and compute chi2
        for fiber in range(nfibers) :

            #if fiber%10==0 :
            #    log.info("iter %d fiber %d (smoothing)"%(iteration,fiber))

            ### R = Resolution(resolution_data[fiber])
            R = frame.R[fiber]

            #M = np.array(np.dot(R.todense(),mean_spectrum)).flatten()
            M = R.dot(mean_spectrum)

            # (M==0) is added to the denominator to avoid dividing by zero;
            # those entries get zero weight in the spline fit
            F = flux[fiber]/(M+(M==0))
            smooth_fiberflat[fiber]=spline_fit(wave,wave,F,smoothing_res,current_ivar[fiber]*(M!=0))
            chi2[fiber]=current_ivar[fiber]*(flux[fiber]-smooth_fiberflat[fiber]*M)**2

        log.info("rejecting")

        nout_iter=0
        if iteration<1 :
            # only remove worst outlier per wave
            # apply rejection iteratively, only one entry per wave among fibers
            # find waves with outlier (fastest way)
            nout_per_wave=np.sum(chi2>nsig_clipping**2,axis=0)
            selection=np.where(nout_per_wave>0)[0]
            for i in selection :
                worst_entry=np.argmax(chi2[:,i])
                current_ivar[worst_entry,i]=0
                sqrtwflat[worst_entry,i]=0
                sqrtwflux[worst_entry,i]=0
                nout_iter += 1

        else :
            # remove all of them at once
            bad=(chi2>nsig_clipping**2)
            current_ivar *= (bad==0)
            sqrtwflat *= (bad==0)
            sqrtwflux *= (bad==0)
            nout_iter += np.sum(bad)

        nout_tot += nout_iter

        sum_chi2=float(np.sum(chi2))
        ndf=int(np.sum(chi2>0)-nwave-nfibers*(nwave/smoothing_res))
        chi2pdf=0.
        if ndf>0 :
            chi2pdf=sum_chi2/ndf
        log.info("iter #%d chi2=%f ndf=%d chi2pdf=%f nout=%d"%(iteration,sum_chi2,ndf,chi2pdf,nout_iter))

        # normalize to get a mean fiberflat=1
        mean=np.mean(smooth_fiberflat,axis=0)
        smooth_fiberflat = smooth_fiberflat/mean
        mean_spectrum    = mean_spectrum*mean



        if nout_iter == 0 :
            break

    log.info("nout tot=%d"%nout_tot)

    # now use mean spectrum to compute flat field correction without any smoothing
    # because sharp feature can arise if dead columns

    fiberflat=np.ones((flux.shape))
    fiberflat_ivar=np.zeros((flux.shape))
    # explicit 64-bit integer dtype: the Python 2 builtin `long` used here
    # originally does not exist on Python 3
    mask=np.zeros((flux.shape)).astype(np.int64)

    fiberflat_mask=12 # place holder for actual mask bit when defined

    nsig_for_mask=4 # only mask out 4 sigma outliers

    for fiber in range(nfibers) :
        ### R = Resolution(resolution_data[fiber])
        R = frame.R[fiber]
        M = np.array(np.dot(R.todense(),mean_spectrum)).flatten()
        # where the convolved mean spectrum is zero the flat is set to 1
        fiberflat[fiber] = (M!=0)*flux[fiber]/(M+(M==0)) + (M==0)
        fiberflat_ivar[fiber] = ivar[fiber]*M**2
        # smooth version of this fiber's flat, only used to flag outliers
        smooth_flat=spline_fit(wave,wave,fiberflat[fiber],smoothing_res,current_ivar[fiber]*M**2*(M!=0))
        bad=np.where(fiberflat_ivar[fiber]*(fiberflat[fiber]-smooth_flat)**2>nsig_for_mask**2)[0]
        if bad.size>0 :
            mask[fiber,bad] += fiberflat_mask

    return FiberFlat(wave, fiberflat, fiberflat_ivar, mask, mean_spectrum)
Ejemplo n.º 50
0
def main(args):
    """Add the spectra of all exposures of one night to per-band brick files.

    For every exposure of ``args.night`` that has a fibermap, each available
    cframe is read and its fibers are appended to the brick file matching
    their band and brick name.  Brick files are opened in 'update' mode the
    first time they are needed and closed once all exposures are processed.

    Args:
        args: parsed command-line options; the attributes read here are
            ``verbose``, ``night`` and ``specprod``.

    Returns:
        -1 if ``args.night`` is None, -2 if a RuntimeError is raised while
        processing, None otherwise.
    """

    if args.verbose:
        log = get_logger(DEBUG)
    else:
        log = get_logger()

    if args.night is None:
        log.critical('Missing required night argument.')
        return -1

    # Initialize a dictionary of Brick objects indexed by '<band>_<brick-id>' strings.
    bricks = {}
    try:
        # Loop over exposures available for this night.
        for exposure in desispec.io.get_exposures(args.night,
                                                  specprod_dir=args.specprod):
            # Ignore exposures with no fibermap, assuming they are calibration data.
            fibermap_path = desispec.io.findfile(filetype='fibermap',
                                                 night=args.night,
                                                 expid=exposure,
                                                 specprod_dir=args.specprod)
            if not os.path.exists(fibermap_path):
                log.debug('Skipping exposure %08d with no fibermap.' %
                          exposure)
                continue
            # Open the fibermap.
            fibermap_data = desispec.io.read_fibermap(fibermap_path)
            brick_names = set(fibermap_data['BRICKNAME'])
            # Loop over per-camera cframes available for this exposure.
            cframes = desispec.io.get_files(filetype='cframe',
                                            night=args.night,
                                            expid=exposure,
                                            specprod_dir=args.specprod)
            log.debug(
                'Exposure %08d covers %d bricks and has cframes for %s.' %
                (exposure, len(brick_names), ','.join(cframes.keys())))
            # .items() instead of the Python-2-only .iteritems()
            for camera, cframe_path in cframes.items():
                # Camera names are a band letter followed by the spectrograph number.
                band, spectro_id = camera[0], int(camera[1:])
                this_camera = (fibermap_data['SPECTROID'] == spectro_id)
                # Read this cframe file.
                frame = desispec.io.read_frame(cframe_path)
                # Loop over bricks.
                for brick_name in brick_names:
                    # Lookup the fibers belong to this brick.
                    this_brick = (fibermap_data['BRICKNAME'] == brick_name)
                    brick_data = fibermap_data[this_camera & this_brick]
                    # Fiber ids are reduced modulo 500 to index the frame arrays.
                    fibers = np.mod(brick_data['FIBER'], 500)
                    if len(fibers) == 0:
                        continue
                    brick_key = '%s_%s' % (band, brick_name)
                    # Open the brick file if this is the first time we are using it.
                    if brick_key not in bricks:
                        brick_path = desispec.io.findfile('brick',
                                                          brickname=brick_name,
                                                          band=band)
                        header = dict(
                            BRICKNAM=(brick_name, 'Imaging brick name'),
                            CHANNEL=(band, 'Spectrograph channel [b,r,z]'),
                        )
                        bricks[brick_key] = desispec.io.brick.Brick(
                            brick_path, mode='update', header=header)
                    # Add these fibers to the brick file. Note that the wavelength array is
                    # not per-fiber, so we do not slice it before passing it to add_objects().
                    bricks[brick_key].add_objects(
                        frame.flux[fibers], frame.ivar[fibers], frame.wave,
                        frame.resolution_data[fibers], brick_data, args.night,
                        exposure)
        # Close all brick files.
        # .values() instead of the Python-2-only .itervalues()
        for brick in bricks.values():
            log.debug(
                'Brick %s now contains %d spectra for %d targets.' %
                (brick.path, brick.get_num_spectra(), brick.get_num_targets()))
            brick.close()

    except RuntimeError as e:
        log.critical(str(e))
        return -2
Ejemplo n.º 51
0
def main() :
    """Find the best model of every standard star in a frame and normalize its flux.

    Command-line options are parsed from sys.argv.  The standard stars
    (OBJTYPE == "STD") of the requested spectrograph are taken from the
    fibermap, their b, r and z spectra are flat-fielded and sky-subtracted,
    the best matching template is searched with match_templates() and
    normalized with normalize_templates().  The normalized model fluxes are
    written to --outfile, to be used later for calibration.

    Raises RuntimeError when DESI_SPECTRO_REDUX or DESISIM is not set in the
    environment.  Exits with status 12 when a required option is missing and
    with status 1 when the fibermap holds no standard star.
    """

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--fiberflatexpid', type = int, help = 'fiberflat exposure ID')
    parser.add_argument('--fibermap', type = str, help = 'path of fibermap file')
    parser.add_argument('--models', type = str, help = 'path of spectro-photometric stellar spectra fits')
    parser.add_argument('--spectrograph', type = int, default = 0, help = 'spectrograph number, can go 0-9')
    parser.add_argument('--outfile', type = str, help = 'output file for normalized stdstar model flux')

    args = parser.parse_args()
    log = get_logger()
    # Call necessary environment variables. No need if add argument to give full file path.
    if 'DESI_SPECTRO_REDUX' not in os.environ:
        raise RuntimeError('Set environment DESI_SPECTRO_REDUX. It is needed to read the needed datafiles')

    # Neither value is used directly below; the lookups are kept so that a
    # missing variable still stops the script here.
    DESI_SPECTRO_REDUX=os.environ['DESI_SPECTRO_REDUX']
    PRODNAME=os.environ['PRODNAME']
    if 'DESISIM' not in os.environ:
        raise RuntimeError('Set environment DESISIM. It will be neede to read the filter transmission files for calibration')

    DESISIM=os.environ['DESISIM']   # to read the filter transmission files

    if args.fibermap is None or args.models is None or \
       args.spectrograph is None or args.outfile is None or \
       args.fiberflatexpid is None:
        log.critical('Missing a required argument')
        parser.print_help()
        sys.exit(12)

    # read Standard Stars from the fibermap file
    # returns the Fiber id, filter names and mags for the standard stars

    fiber_tbdata,fiber_header=io.read_fibermap(args.fibermap, header=True)

    #- Trim to just fibers on this spectrograph
    ii =  (500*args.spectrograph <= fiber_tbdata["FIBER"])
    ii &= (fiber_tbdata["FIBER"] < 500*(args.spectrograph+1))
    fiber_tbdata = fiber_tbdata[ii]

    #- Get info for the standard stars
    #- (all three arrays are restricted to the standard-star rows, so they
    #-  can be indexed by the position of the star in the list)
    refStarIdx=np.where(fiber_tbdata["OBJTYPE"]=="STD")
    refFibers=fiber_tbdata["FIBER"][refStarIdx]
    refFilters=fiber_tbdata["FILTER"][refStarIdx]
    refMags=fiber_tbdata["MAG"][refStarIdx]

    fibers={"FIBER":refFibers,"FILTER":refFilters,"MAG":refMags}

    NIGHT=fiber_header['NIGHT']
    EXPID=fiber_header['EXPID']
    basepath=DESISIM+"/data/"

    #now load all the skyfiles, framefiles, fiberflatfiles etc
    # all three channels files are simultaneously treated for model fitting
    skyfile={}
    framefile={}
    fiberflatfile={}
    for i in ["b","r","z"]:
        camera = i+str(args.spectrograph)
        skyfile[i] = io.findfile('sky', NIGHT, EXPID, camera)
        framefile[i] = io.findfile('frame', NIGHT, EXPID, camera)
        fiberflatfile[i] = io.findfile('fiberflat', NIGHT, args.fiberflatexpid, camera)

    #Read Frames, Flats and Sky files (only the arrays used below are kept)
    frameFlux={}
    frameIvar={}
    frameWave={}
    frameResolution={}
    fiberFlat={}
    sky={}

    for i in ["b","r","z"]:
        frame = io.read_frame(framefile[i])
        frameFlux[i] = frame.flux
        frameIvar[i] = frame.ivar
        frameWave[i] = frame.wave
        frameResolution[i] = frame.resolution_data

        ff = io.read_fiberflat(fiberflatfile[i])
        fiberFlat[i] = ff.fiberflat

        skymodel = io.read_sky(skyfile[i])
        sky[i] = skymodel.flux

    # Convolve Sky with Detector Resolution, so as to subtract from data. Convolve for all 500 specs. Subtracting sky this way should be equivalent to sky_subtract

    convolvedsky={"b":sky["b"], "r":sky["r"], "z":sky["z"]}

    # Read the standard Star data and divide by flat and subtract sky

    stars=[]
    ivars=[]
    for k,fiber in enumerate(fibers["FIBER"]):
        #flat and sky should have same wavelength binning as data, otherwise should be rebinned.

        # The per-camera arrays hold the fibers of one spectrograph only, so
        # they are indexed with the fiber id modulo 500.
        i = fiber % 500
        stars.append((fiber,{"b":[frameFlux["b"][i]/fiberFlat["b"][i]-convolvedsky["b"][i],frameWave["b"]],
                             "r":[frameFlux["r"][i]/fiberFlat["r"][i]-convolvedsky["r"][i],frameWave["r"]],
                             "z":[frameFlux["z"][i]/fiberFlat["z"][i]-convolvedsky["z"][i],frameWave["z"]]},fibers["MAG"][k]))
        ivars.append((fiber,{"b":[frameIvar["b"][i]],"r":[frameIvar["r"][i,:]],"z":[frameIvar["z"][i,:]]}))


    stdwave,stdflux,templateid=io.read_stdstar_templates(args.models)

    #- Trim standard star wavelengths to just the range we need
    minwave = min([min(w) for w in frameWave.values()])
    maxwave = max([max(w) for w in frameWave.values()])
    ii = (minwave-10 < stdwave) & (stdwave < maxwave+10)
    stdwave = stdwave[ii]
    stdflux = stdflux[:, ii]

    log.info('Number of Standard Stars in this frame: {0:d}'.format(len(stars)))
    if len(stars) == 0:
        log.critical("No standard stars!  Exiting")
        sys.exit(1)

    # Now for each star, find the best model and normalize.

    normflux=[]
    bestModelIndex=np.arange(len(stars))   # integer array, overwritten star by star
    chi2dof=np.zeros(len(stars))

    for k,star in enumerate(stars):
        starfiber=star[0]
        log.info("checking best model for star {0}".format(starfiber))

        mags=star[2]
        filters=fibers["FILTER"][k]
        flux={"b":star[1]["b"][0],"r":star[1]["r"][0],"z":star[1]["z"][0]}
        ivar={"b":ivars[k][1]["b"][0],"r":ivars[k][1]["r"][0],"z":ivars[k][1]["z"][0]}

        # same modulo-500 row as used for the fluxes above
        row=starfiber % 500
        resol_star={"r":frameResolution["r"][row],"b":frameResolution["b"][row],"z":frameResolution["z"][row]}

        # Now find the best Model

        bestModelIndex[k],bestmodelWave,bestModelFlux,chi2dof[k]=match_templates(frameWave,flux,ivar,resol_star,stdwave,stdflux)

        log.info('Star Fiber: {0}; Best Model Fiber: {1}; TemplateID: {2}; Chisq/dof: {3}'.format(starfiber,bestModelIndex[k],templateid[bestModelIndex[k]],chi2dof[k]))
        # Normalize the best model using reported magnitude
        modelwave,normalizedflux=normalize_templates(stdwave,stdflux[bestModelIndex[k]],mags,filters,basepath)
        normflux.append(normalizedflux)

    # Now write the normalized flux for all best models to a file
    normflux=np.array(normflux)
    stdfibers=fibers["FIBER"]
    data={}
    data['BESTMODEL']=bestModelIndex
    data['CHI2DOF']=chi2dof
    data['TEMPLATEID']=templateid[bestModelIndex]
    norm_model_file=args.outfile
    io.write_stdstar_model(norm_model_file,normflux,stdwave,stdfibers,data)
Ejemplo n.º 52
0
def plot_graph(frame,
               fibers,
               opt_err=False,
               opt_2d=False,
               label=None,
               subplot=None):
    """Plot the spectra of the requested fibers of an opened frame file.

    ----------
    Parameters
    ----------

    frame : opened frame file
    Indexed with "FLUX", "IVAR" and "WAVELENGTH"; the ``.data`` of each
    entry is read.

    fibers : fibers to show
    Array of fiber indices; entries not smaller than the number of fibers in
    the file are dropped with a warning (this uses boolean-mask indexing).

    opt_err : if True, draw error bars computed from IVAR (zero where
    ivar is zero).

    opt_2d : if True, also show the selected spectra as a 2D image in a
    separate figure.

    label : optional prefix of the legend entries, also used as the title of
    the 2D figure.

    subplot : axes to draw into; ``plt.subplot(1, 1, 1)`` when None.

    Nothing is returned; everything is drawn through matplotlib.
    """

    log = get_logger()
    spectra = frame["FLUX"].data
    ivar = frame["IVAR"].data
    wave = frame["WAVELENGTH"].data
    nfibers = spectra.shape[0]

    if np.max(fibers) >= nfibers:
        log.warning(
            "requested fiber numbers %s exceed number of fibers in file %d" %
            (str(fibers), nfibers))
        fibers = fibers[fibers < nfibers]

    if subplot is None:
        subplot = plt.subplot(1, 1, 1)

    # The wavelength array is either shared (1D) or given per fiber (2D).
    wave_per_fiber = len(wave.shape) > 1

    if opt_err:
        # Same for all fibers, so computed once.  (ivar == 0) keeps the
        # division finite, (ivar > 0) then zeroes those errors.
        err = np.sqrt(1. / (ivar + (ivar == 0))) * (ivar > 0)

    for fiber in fibers:

        if label:
            fiber_label = "%s Fiber #%d" % (label, fiber)
        else:
            fiber_label = "Fiber #%d" % fiber

        log.debug("Plotting fiber %03d" % fiber)
        fiber_wave = wave[fiber] if wave_per_fiber else wave
        if opt_err:
            subplot.errorbar(fiber_wave,
                             spectra[fiber],
                             err[fiber],
                             fmt="o-",
                             label=fiber_label)
        else:
            subplot.plot(fiber_wave, spectra[fiber], "-", label=fiber_label)

    subplot.set_xlabel("Wavelength [A]")

    if opt_2d:
        title = "spectra"
        if label is not None:
            title = label
        plt.figure(title)
        if len(wave.shape) == 1:
            # a shared wavelength grid can label the vertical axis
            vertical_range = (wave[0], wave[-1])
            vertical_label = "Wavelength [A]"
        else:
            vertical_range = (0, spectra.shape[1])
            vertical_label = "Y CCD"
        # origin must be 'upper' or 'lower'; the numeric 0. used before was
        # handled like 'lower' by old matplotlib and is rejected by recent ones
        plt.imshow(spectra[fibers].T,
                   aspect='auto',
                   extent=(fibers[0] - 0.5, fibers[-1] + 0.5) + vertical_range,
                   origin="lower",
                   interpolation="nearest")
        plt.ylabel(vertical_label)
        plt.xlabel("Fiber #")
        plt.colorbar()
Ejemplo n.º 53
0
""" Class to organize QA for a full DESI production run
"""

from __future__ import print_function, absolute_import, division, unicode_literals

import numpy as np
import glob, os

from desispec.io import get_exposures
from desispec.io import get_files
from desispec.io import read_frame

from desispec.log import get_logger

log = get_logger()


class QA_Prod(object):
    def __init__(self, specprod_dir):
        """Organize and execute QA for a DESI production.

        Args:
            specprod_dir(str): Path containing the exposures/ directory to use. If the value
                is None, then the value of :func:`specprod_root` is used instead.

        Attributes:
            qa_exps : list
              List of QA_Exposure classes, one per exposure in production
            data : dict

        NOTE(review): no statement of this constructor is visible in this
        excerpt, so the None fallback and the attributes listed above could
        not be checked against the code -- confirm.
        """
Ejemplo n.º 54
0
def main() :
    """
    parsezbest.py computes common metrics and makes plots for analyzing results of zdc1 redshift challenge (zbest file).

    Metrics:
    + dz = (zbest-ztrue)/(1+ztrue)
    + dv = c*dz
    + pull = (zbest-ztrue)/zerr
    + precision: sigma_z = std(dz), sigma_v = std(dv), nmad_z, nmad_v
    + accuracy (bias): mu_z = mean(dz), mu_v = mean(dv)
    + efficiency
    + purity
    + % of catastrophic failures
    + FOM = purity*efficiency

    Results are stored in outfile (default parezbest_results.dat)

    Plots:
    + Histograms dz, dv, pull
    + dz as a function of zt and zb for:
        - zwarn=0
        - zwarn=0 without catastrophic failures
        - zwarn!=0
    + dz as a function of average S/N per wavelegnth bin for:
        - zwarn=0                                                                                                                                                                                     
        - zwarn=0 without catastrophic failures
        - zwarn!=0      
    
    Color code:
    + zwarn=0: blue filled circles
    + zwarn !=0: red filled circles
    + catastrophic failures: green filled circles

    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,prog='parsezbest.py', usage='%(prog)s [options]\n\n parsezbest.py without options runs a demo')

    parser.add_argument('--b', type = str, default = None, required=False,
                        help = 'path of DESI brick in b')
    parser.add_argument('--r', type = str, default = None, required=False,
                        help = 'path of DESI brick in r')
    parser.add_argument('--z', type = str, default = None, required=False,
                        help = 'path of DESI brick in z')
    parser.add_argument('--outfile', type = str, default = "parsezbest_results.dat", required=False,
                        help = 'path of output file')
    parser.add_argument('--pathtruth', type = str, default = None, required=False,
                        help = 'path of truth table if does not exist in bricks')
    parser.add_argument('--zbest', type = str, default = None, required=False,
                        help = 'zbest file')
    parser.add_argument('--plot', dest='plots', action='store_true', help = 'Plots are optional. They are put to the screen if --plot is used.')
    parser.set_defaults(plots=False)

    args = parser.parse_args()
    log=get_logger()

    file=open(args.outfile,"w")                                                                                                                                                                        

    log.info("starting")

    #- if no arguments is passed, parsezbest runs a demo
    if ((args.zbest is None) and (args.b is None) and (args.r is None) and (args.z is None)):
        args.zbest = "%s/data/zbest-training-elg-100-zztop.fits"%(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(desibest.__file__)))))
        args.b = "%s/data/brick-b-elg-100-zztop.fits"%(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(desibest.__file__)))))
        args.r = "%s/data/brick-r-elg-100-zztop.fits"%(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(desibest.__file__)))))
        args.z = "%s/data/brick-z-elg-100-zztop.fits"%(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(desibest.__file__)))))
    elif ((args.zbest is None) or (args.b is None) or (args.r is None) or (args.z is None)):
            log.error("Either all files (b, r and z bricks and zbest) or none should be provided")
            sys.exit(12)

    try:
        b_brick=Brick(args.b)
    except:
        log.error("can not open brick %s"%args.b)
        print sys.exc_info()
        file.close()
        sys.exit(12)
    try:
        r_brick=Brick(args.r)
    except:
        log.error("can not open brick %s"%args.r)
        print sys.exc_info()
        file.close()
        sys.exit(12)
    try:
        z_brick=Brick(args.z)
    except:
        log.error("can not open brick %s"%args.z)
        print sys.exc_info()
        file.close()
        sys.exit(12)

    try :
        zb_hdulist=fits.open(args.zbest)
    except :
        log.error("can not open file %s:"%args.zbest)
        print sys.exc_info()
        file.close()
        sys.exit(12)

    log.info("Using zbest file %s"%args.zbest)
    log.info("Using %s b brick"%args.b)
    log.info("Using %s r brick"%args.r)
    log.info("Using %s z brick"%args.z)
    log.info(" ")

    file.write("Using zbest file %s\n"%args.zbest)
    file.write("Using %s b brick\n"%args.b)
    file.write("Using %s r brick\n"%args.r)
    file.write("Using %s z brick\n"%args.z)
    file.write("\n")

    brickname = args.b
    
    c=3.e5 # light celerity in vacuum [km/s]

#- Checking zbest structure and format

    log.info("Checking zbest structure and format")
    log.info("-----------------------------------")
    if(zb_hdulist[0].size == 0 and zb_hdulist[1].size != 0):
        log.info("zbest file has the required structure")
        if (zb_hdulist[1].name != 'ZBEST'):
            log.warning("HDU extension 1 has name %s, not ZBEST"%zb_hdulist[1].name)
        log.info(" ")
    else:
        log.error("zbest file does not have the required structure")
        log.error("Check ZBEST structure at http://desidatamodel.readthedocs.org/en/latest/DESI_SPECTRO_REDUX/PRODNAME/bricks/BRICKNAME/zbest-BRICKNAME.html")
        file.close()
        print sys.exc_info()
        sys.exit(12)

    zb_name = zb_hdulist[1].name

    keys=['TARGETID', 'Z', 'ZERR', 'ZWARN', 'TYPE']
    subkeys = ['BRICKNAME', 'SUBTYPE']
    for k in keys:
        try :
            zb_hdulist[zb_name].columns.names.index(k)
        except :
            log.error("Missing column %s in %s"%(k,args.zbest))
            log.error("Check ZBEST format at http://desidatamodel.readthedocs.org/en/latest/DESI_SPECTRO_REDUX/PRODNAME/bricks/BRICKNAME/zbest-BRICKNAME.html")
            file.close()
            print sys.exc_info()
            sys.exit(12)
        else:
            log.info("ZBEST hdu has required column %s"%k)
    for sk in subkeys:
        try:
            zb_hdulist[zb_name].columns.names.index(k)
            log.info("ZBEST hdu has optional column %s"%sk)
        except:
            log.warning("ZBEST hdu misses optional %s column"%sk)
    log.info(" ")

#- checking for truth table

    log.info("Checking for truth table")
    log.info("-----------------------------------")
    try:
        b_hdulist = fits.open(args.b)
        truth_table_hdu=b_hdulist['_TRUTH']
        log.info("brick has truth table")
        log.info(" ")
    except KeyError :
        truth_table_hdu=None
    
    if (truth_table_hdu is None):
        try:
            truthfile = fits.open(args.pathtruth)
            truth_table_hdu = truthfile['_TRUTH']
            log.info("Found truth table %s"%args.pathtruth)
            log.info(" ")
        except:
            log.error("A truth table should be provided")
            file.close()
            sys.exit(12)

#- Get results from zbest hdu and infos from truth table
    truth = truth_table_hdu.data
    zbres=zb_hdulist[zb_name].data

    zb = zbres['Z']
    zt = truth['TRUEZ']
    zw = zbres['ZWARN']
    
#- joining zbest and truth tables  

    # checks that targetids have the same dtype in truth and zbres
    if (zbres['TARGETID'].dtype != truth['TARGETID'].dtype):
        zbres = np.asarray(zbres,dtype=[('BRICKNAME', 'S20'), ('TARGETID', '>i8'), ('Z', '>f8'), ('ZERR', '>f8'), ('ZWARN', '>f8'), ('TYPE', 'S20')])

    zb_zt = join(zbres, truth, keys='TARGETID')

    truez=zb_zt['TRUEZ']
    bestz=zb_zt['Z']
    errz=zb_zt['ZERR']
    zwarn = zb_zt['ZWARN']
    n=bestz.size
    if (n == 0):
        log.error("target ids in zbest file are not a subset of target ids in truth table")                                                                                                            
        log.error("did you provide the bricks that correspond to zbest file ?")
        file.close()
        sys.exit(12) 

#- Select objtype

    log.info("Target type")
    log.info("-----------------------------------")

    obj=dict()
    objtypes=['ELG','LRG','QSO','QSO_BAD','STAR']
    totobj = len(zb_zt['Z'])

    for o in objtypes:
        index=np.where(zb_zt['OBJTYPE'] == '%s'%o)[0]
        obj[o]=len(index)
        log.info("%i %s found"%(obj[o],o))
        file.write("%i %s found\n"%(obj[o],o))
        if (obj[o] != 0): 
            log.info(" ")
            file.write("\n")

            #load requirements for target
            req= desi_requirement(o)
            
            tz = np.zeros(len(index))
            bz = np.zeros(len(index))
            zw = np.zeros(len(index))
            dv = np.zeros(len(index))
            ez = np.zeros(len(index))
            if (o == 'ELG'): 
                trfloii = np.zeros(len(index))

            dz=0.
            dv=0.
            pull=0.

            for i,j in zip(index,range(len(index))):
                tz[j]=truez[i]
                bz[j]=bestz[i]
                zw[j]=zwarn[i]
                ez[j]=errz[i]
                if (o == 'ELG'):
                    trfloii[j] = zb_zt["OIIFLUX"][i]
            dv = c*(bz-tz)/(1+tz)
            dz=dv/c

            true_pos = np.where((np.abs(dz)<0.0033) & (zw==0))[0]
            true_neg = np.where((np.abs(dz)>0.0033) & (zw!=0))[0]
            false_pos = np.where((np.abs(dz)>0.0033) & (zw==0))[0]
            false_neg = np.where((np.abs(dz)<0.0033) & (zw!=0))[0]

            #- total
            total = len(true_pos)+len(true_neg)+len(false_pos)+len(false_neg)

            #- computes sample efficiency
            efficiency = float(len(true_pos))/float(total)
            
            #- computes purity
            purity = float(len(true_pos))/float((len(true_pos)+len(false_pos)))

            #- catastrophic failures
#            cata_fail = float(len(false_pos))/float(total)
            cata_fail = float(len(false_pos))/float((len(true_pos)+len(false_pos)))

            #- figure of merit
            fom = efficiency*purity

            # precision
            #- sigma
            ok = np.where(zw==0)[0]
            zerr = np.std(dz[ok])
            verr = np.std(dv[ok])
            #- quantiles
            dz_quant = np.percentile(dz[ok], (2.5, 16, 50, 84, 97.5))
            dv_quant = np.percentile(dv[ok], (2.5, 16, 50, 84, 97.5))
            dz_err68 = (dz_quant[3] - dz_quant[1])/2.
            dz_err95 = (dz_quant[4] - dz_quant[0])/2.
            dv_err68 = (dv_quant[3] - dv_quant[1])/2.
            dv_err95 = (dv_quant[4] - dv_quant[0])/2.            

            #accuracy
            #- mu
            zacc = np.mean(dz[ok])
            vacc = np.mean(dv[ok])
            #- quantile
            dz_acc50 = dz_quant[2]
            dv_acc50 = dv_quant[2]


            ok_no_cata = np.where((zw==0) & (np.abs(dz)<0.0033))[0]

            if (len(ok) != len(ok_no_cata)):
                zerr_no_cata = np.std(dz[ok_no_cata])
                verr_no_cata = np.std(dv[ok_no_cata])
                zacc_no_cata = np.mean(dz[ok_no_cata])
                vacc_no_cata = np.mean(dv[ok_no_cata])
                dz_quant_no_cata = np.percentile(dz[ok_no_cata], (2.5, 16, 50, 84, 97.5))
                dv_quant_no_cata = np.percentile(dv[ok_no_cata], (2.5, 16, 50, 84, 97.5))
                dz_err68_no_cata = (dz_quant_no_cata[3] - dz_quant_no_cata[1])/2.
                dz_err95_no_cata = (dz_quant_no_cata[4] - dz_quant_no_cata[0])/2.
                dv_err68_no_cata = (dv_quant_no_cata[3] - dv_quant_no_cata[1])/2.
                dv_err95_no_cata = (dv_quant_no_cata[4] - dv_quant_no_cata[0])/2.
                dz_acc50_no_cata = dz_quant_no_cata[2]
                dv_acc50_no_cata = dv_quant_no_cata[2]

            # NMAD
            nmad_z = np.median(np.abs(dz[ok]-np.median(dz[ok])))
            nmad_z *= 1.4826
            nmad_v = np.median(np.abs(dv[ok]-np.median(dv[ok])))
            nmad_v *= 1.4826

            #pull
            pull_ok_no_cata = np.where((zw==0) & (np.abs(dz)<0.0033) & (ez>0))[0]
            pull = (bz[pull_ok_no_cata]-tz[pull_ok_no_cata])/ez[pull_ok_no_cata]
            mu_pull = np.mean(pull)
            sigma_pull = np.std(pull)

            # zwarn
            zw0 = len(np.where(zw == 0)[0])
            zw_non0 = len(np.where(zw != 0)[0])
#            assert (zw0 + zw_non0) != len(index), "zw=0 + zw!=0 not equal to number of objects"

            log.info("=====================================")
            log.info("%s: Precision and accuracy (zwarn=0)"%o)
            log.info("=====================================")
            log.info("sigma_z: %f, mu_z: %f"%(zerr,zacc))
            log.info("quantile: dz_err68: %f"%(dz_err68))
            log.info("quantile: dz_err95: %f"%(dz_err95))
            log.info("quantile: dz_acc50: %f"%(dz_acc50))
            log.info("NMAD_z: %f"%nmad_z)
            log.info("sigma_v: %f, mu_v: %f"%(verr,vacc))
            log.info("quantile: dv_err68: %f"%(dv_err68))
            log.info("quantile: dv_err95: %f"%(dv_err95))
            log.info("quantile: dv_acc50: %f"%(dv_acc50))
            log.info("NMAD_v: %f"%nmad_v)
            if req is not None:
                if (dz_err68>req['SIG_Z']):
                    log.info("dz_err68 & dv_err68 do not meet DESI requirements on precision for %s"%o)
                if (dz_acc50>req['BIAS_Z']):
                    log.info("dz_acc50 & dv_acc50 do not meet DESI requirements on bias for %s"%o)
            log.info(" ")
            if (len(ok) != len(ok_no_cata)):
                log.info("=====================================")
                log.info("%s: Precision and accuracy "%o)
                log.info("zwarn=0 without catastrophic failures")
                log.info("=====================================")
                log.info("sigma_z: %f, mu_z: %f"%(zerr_no_cata,zacc_no_cata))
                log.info("quantile: dz_err68: %f"%(dz_err68_no_cata))
                log.info("quantile: dz_err95: %f"%(dz_err95_no_cata))
                log.info("quantile: dz_acc50: %f"%(dz_acc50_no_cata))
                log.info("sigma_v: %f, mu_v: %f"%(verr_no_cata,vacc_no_cata))
                log.info("quantile: dv_err68: %f"%(dv_err68_no_cata))
                log.info("quantile: dv_err95: %f"%(dv_err95_no_cata))
                log.info("quantile: dv_acc50: %f"%(dv_acc50_no_cata))
                if req is not None:
                    if (dz_err68_no_cata>req['SIG_Z']):
                        log.info("dz_err68 & dv_err68 do not meet DESI requirements on precision for %s"%o)
                    if (dz_acc50_no_cata>req['BIAS_Z']):
                        log.info("dz_acc50 & dv_acc50 do not meet DESI requirements on bias for %s"%o)
                log.info(" ")

            file.write("=====================================\n")
            file.write("%s: Precision and accuracy (zwarn=0)\n"%o)
            file.write("=====================================\n")
            file.write("sigma_z: %f, mu_z: %f\n"%(zerr,zacc))
            file.write("quantile: dz_err68: %f\n"%(dz_err68))
            file.write("quantile: dz_err95: %f\n"%(dz_err95))
            file.write("quantile: dz_acc50: %f\n"%(dz_acc50))
            file.write("NMAD_z: %f\n"%nmad_z)
            file.write("sigma_v: %f, mu_v: %f\n"%(verr,vacc))
            file.write("quantile: dv_err68: %f\n"%(dv_err68))
            file.write("quantile: dv_err95: %f\n"%(dv_err95))
            file.write("quantile: dv_acc50: %f\n"%(dv_acc50))
            file.write("NMAD_v: %f\n"%nmad_v)
            if req is not None:
                if (dz_err68>req['SIG_Z']):
                    file.write("dz_err68 & dv_err68 do not meet DESI requirements on precision for %s\n"%o)
                if (dz_acc50>req['BIAS_Z']):
                    file.write("dz_acc50 & dv_acc50 do not meet DESI requirements on bias for %s\n"%o)
            file.write("\n")
            if (len(ok) != len(ok_no_cata)):
                file.write("=====================================\n")
                file.write("%s: Precision and accuracy \n"%o)
                file.write("zwarn=0 without catastrophic failures\n")
                file.write("=====================================\n")
                file.write("sigma_z: %f, mu_z: %f\n"%(zerr_no_cata,zacc_no_cata))
                file.write("quantile: dz_err68: %f\n"%(dz_err68_no_cata))
                file.write("quantile: dz_err95: %f\n"%(dz_err95_no_cata))
                file.write("quantile: dz_acc50: %f\n"%(dz_acc50_no_cata))
                file.write("sigma_v: %f, mu_v: %f\n"%(verr_no_cata,vacc_no_cata))
                file.write("quantile: dv_err68: %f\n"%(dv_err68_no_cata))
                file.write("quantile: dv_err95: %f\n"%(dv_err95_no_cata))
                file.write("quantile: dv_acc50: %f\n"%(dv_acc50_no_cata))
                if req is not None:
                    if (dz_err68_no_cata>req['SIG_Z']):
                        file.write("dz_err68 & dv_err68 do not meet DESI requirements on precision for %s\n"%o)
                    if (dz_acc50_no_cata>req['BIAS_Z']):
                        file.write("dz_acc50 & dv_acc50 do not meet DESI requirements on bias for %s\n"%o)
                file.write("\n")



            log.info("=====================================")
            log.info("%s: Pull (zwarn=0, no cata. fail.)"%o)
            log.info("=====================================")
            log.info("mu: %f, sigma: %f"%(mu_pull,sigma_pull))
            log.info(" ")

            file.write("=====================================\n")
            file.write("%s: Pull (zwarn=0, no cata. fail.)\n"%o)
            file.write("=====================================\n")
            file.write("mu: %f, sigma: %f\n"%(mu_pull,sigma_pull))
            file.write("\n")

            if (o == 'ELG'):
                true_pos_oII = np.where((np.abs(dz)<0.0033) & (zw==0) & (trfloii>8e-17))[0]
                true_neg_oII = np.where((np.abs(dz)>0.0033) & (zw!=0) & (trfloii>8e-17))[0]
                false_pos_oII = np.where((np.abs(dz)>0.0033) & (zw==0) & (trfloii>8e-17))[0]
                false_neg_oII = np.where((np.abs(dz)<0.0033) & (zw!=0) & (trfloii>8e-17))[0]

                #- total                                                                              
                total_oII = len(true_pos_oII)+len(true_neg_oII)+len(false_pos_oII)+len(false_neg_oII)
                
                #- computes sample efficiency                                                         
                efficiency_oII = float(len(true_pos_oII))/float(total_oII)
            
                #- computes purity                                                             
                purity_oII = float(len(true_pos_oII))/float((len(true_pos_oII)+len(false_pos_oII)))

                #- catastrophic failures
#                cata_fail_oII = float(len(false_pos_oII))/float(total_oII)
                cata_fail_oII = float(len(false_pos_oII))/float((len(true_pos_oII)+len(false_pos_oII)))

                #- figure of merit
                fom_oII = efficiency_oII*purity_oII
                
                oii8e17 = np.where((np.abs(dz)<0.0033) & (zw==0) & (trfloii>8e-17))[0]
                stdzoii = np.std(dz[oii8e17])
                stdvoii = np.std(dv[oii8e17])
                acczoii = np.mean(dz[oii8e17])
                accvoii = np.mean(dv[oii8e17])
                dz_quant_oii = np.percentile(dz[oii8e17], (2.5, 16, 50, 84, 97.5))
                dv_quant_oii = np.percentile(dv[oii8e17], (2.5, 16, 50, 84, 97.5))
                dz_err68_oii = (dz_quant_oii[3] - dz_quant_oii[1])/2.
                dz_err95_oii = (dz_quant_oii[4] - dz_quant_oii[0])/2.
                dv_err68_oii = (dv_quant_oii[3] - dv_quant_oii[1])/2.
                dv_err95_oii = (dv_quant_oii[4] - dv_quant_oii[0])/2.
                dz_acc50_oii = dz_quant_oii[2]
                dv_acc50_oii = dv_quant_oii[2]


                log.info("=====================================")
                log.info("%s: For OII > 8e-17 erg/s/cm2"%o)
                log.info("=====================================")
                log.info('Efficiency_oII: %d/%d=%f'%(len(true_pos_oII),total_oII,efficiency_oII))
                if req is not None:
                    if (efficiency_oII < req['EFFICIENCY']):
                        log.info("Efficiency_oII does not meet DESI requirements for %s"%o)
                log.info('Purity_oII: %d/%d=%f'%(len(true_pos_oII),(len(true_pos_oII)+len(false_pos_oII)),purity_oII))
                log.info('Catastrophic failures_oII: %d/%d=%f'%(len(false_pos_oII),len(true_pos_oII)+len(false_pos_oII),cata_fail_oII))
                if req is not None:
                    if (cata_fail_oII>req['CATA_FAIL_MAX']):
                        log.info("Catastrophic failure rate does not meet DESI requirements for %s"%o)
                log.info('sigma_z_oii: %f, mu_z_oii: %f'%(stdzoii, acczoii))
                log.info('quantile: dz_err68_oii: %f'%(dz_err68_oii))
                log.info('quantile: dz_err95_oii: %f'%(dz_err95_oii))
                log.info('quantile: dz_acc50_oii: %f'%(dz_acc50_oii))
                log.info('sigma_v_oii: %f, mu_v_oii: %f'%(stdvoii, accvoii))
                log.info('quantile: dv_err68_oii: %f'%(dv_err68_oii))
                log.info('quantile: dv_err95_oii: %f'%(dv_err95_oii))
                log.info('quantile: dv_acc50_oii: %f'%(dv_acc50_oii))
                log.info('FOM_oII: %f x %f=%f'%(efficiency_oII,purity_oII,fom_oII))
                log.info(" ")

                file.write("=====================================\n")
                file.write("%s: For OII > 8e-17 erg/s/cm2\n"%o)
                file.write("=====================================\n")
                file.write('Efficiency_oII: %d/%d=%f\n'%(len(true_pos_oII),total_oII,efficiency_oII))
                if req is not None:
                    if (efficiency_oII < req['EFFICIENCY']):
                        file.write("Efficiency_oII does not meet DESI requirements for %s\n"%o)
                file.write('Purity_oII: %d/%d=%f\n'%(len(true_pos_oII),(len(true_pos_oII)+len(false_pos_oII)),purity_oII))
                file.write('Catastrophic failures_oII: %d/%d=%f\n'%(len(false_pos_oII),len(true_pos_oII)+len(false_pos_oII),cata_fail_oII))
                if req is not None:
                    if (cata_fail_oII>req['CATA_FAIL_MAX']):
                        file.write("Catastrophic failure rate does not meet DESI requirements for %s\n"%o)
                file.write('sigma_z_oii: %f, mu_z_oii: %f\n'%(stdzoii, acczoii))
                file.write('quantile: dz_err68_oii: %f\n'%(dz_err68_oii))
                file.write('quantile: dz_err95_oii: %f\n'%(dz_err95_oii))
                file.write('quantile: dz_acc50_oii: %f\n'%(dz_acc50_oii)) 
                file.write('sigma_v_oii: %f, mu_v_oii: %f\n'%(stdvoii, accvoii))
                file.write('quantile: dv_err68_oii: %f\n'%(dv_err68_oii))
                file.write('quantile: dv_err95_oii: %f\n'%(dv_err95_oii))
                file.write('quantile: dv_acc50_oii: %f\n'%(dv_acc50_oii))
                file.write('FOM_oII: %f x %f=%f\n'%(efficiency_oII,purity_oII,fom_oII))
                file.write("\n")


            log.info("=====================================")
            log.info("%s: Total sample"%o)
            log.info("=====================================")
            log.info("zwarn = 0: %d"%zw0)
            log.info("zwarn !=0: %d"%zw_non0)
            log.info('Efficiency: %d/%d=%f'%(len(true_pos),total,efficiency))
            if req is not None:
                if (efficiency < req['EFFICIENCY']):
                    log.info("Efficiency does not meet DESI requirements for %s"%o)
            log.info('Purity: %d/%d=%f'%(len(true_pos),(len(true_pos)+len(false_pos)),purity))
            log.info('Catastrophic failures: %d/%d=%f'%(len(false_pos),len(true_pos)+len(false_pos),cata_fail))
            if req is not None:
                if (cata_fail>req['CATA_FAIL_MAX']):
                    log.info("Catastrophic failure rate does not meet DESI requirements for %s"%o)
            log.info('FOM: %f x %f=%f'%(efficiency,purity,fom))
            log.info("=====================================")
            log.info(" ")

            file.write("=====================================\n")
            file.write("%s: Total sample\n"%o)
            file.write("=====================================\n")
            file.write("zwarn = 0: %d\n"%zw0)
            file.write("zwarn !=0: %d\n"%zw_non0)
            file.write('Efficiency: %d/%d=%f\n'%(len(true_pos),total,efficiency))
            if req is not None:
                if (efficiency < req['EFFICIENCY']):
                    file.write("Efficiency does not meet DESI requirements for %s\n"%o)
            file.write('Purity: %d/%d=%f\n'%(len(true_pos),(len(true_pos)+len(false_pos)),purity))
            file.write('Catastrophic failures: %d/%d=%f\n'%(len(false_pos),len(true_pos)+len(false_pos),cata_fail))
            if req is not None:
                if (cata_fail>req['CATA_FAIL_MAX']):
                    file.write("Catastrophic failure rate does not meet DESI requirements for %s\n"%o)
            file.write('FOM: %f x %f=%f\n'%(efficiency,purity,fom))
            file.write("=====================================\n")
            file.write("\n")


            # computes spectrum S/N                                                                                                                                                                     
            mean_ston=np.zeros(len(index))
            mean_ston_oII=np.zeros(len(index))
            for spec,sp in zip(index,range(len(index))):
                flux=[b_brick.hdu_list[0].data[spec],r_brick.hdu_list[0].data[spec],z_brick.hdu_list[0].data[spec]]
                ivar=[b_brick.hdu_list[1].data[spec],r_brick.hdu_list[1].data[spec],z_brick.hdu_list[1].data[spec]]
                wave=[b_brick.hdu_list[2].data,r_brick.hdu_list[2].data,z_brick.hdu_list[2].data]
                for i in range(3):
                    mean_ston[sp] += np.sum(np.abs(flux[i])*np.sqrt(ivar[i]))/len(wave[i])

                # computes mean S/N in OII lines for ELG
                if (o == 'ELG'):
                    for i in range(3):
                        ok = np.where((wave[i]>3722) & (wave[i]< 3734))[0]
                        if (len(ok) == 0):
                            break
                        else:
                            mean_ston_oII[sp] += np.sum(np.abs(flux[i][ok])*np.sqrt(ivar[i][ok]))/len(ok)
                
            #- plots

            if (args.plots):
                ok=np.where(zw==0)[0]
                cata = np.where((zw == 0) & (np.abs(dz)>0.0033))[0]
                not_ok = np.where(zw !=0)[0]
#            ok_no_cata = np.where((zw == 0) & (np.abs(dz)<0.0033))[0]

            #- histograms
            
                pylab.figure()
                n, bins, patches = pylab.hist(dz[ok_no_cata], 30, normed=1, histtype='stepfilled')
                pylab.setp(patches, 'facecolor', 'b', 'alpha', 0.75)
                if (o != 'QSO'):
                    muz = np.mean(dz[ok_no_cata])
                    sigmaz = np.std(dz[ok_no_cata])
                    gauss = pylab.normpdf(bins, muz, sigmaz)
                    l = pylab.plot(bins, gauss, 'k--', linewidth=1.5, label="mu=%2.0f *1e-6, sig=%2.0f *1e-6"%(muz/1e-6,sigmaz/1e-6))
                    pylab.legend()
                pylab.xlabel("(zb-zt)/(1+zt) (ZWARN=0 without catastrophic failures)")
                pylab.ylabel("Num. of %s targets per bin"%o)
            

                pylab.figure()
                n, bins, patches = pylab.hist(dv[ok_no_cata], 30, normed=1, histtype='stepfilled')
                pylab.setp(patches, 'facecolor', 'g', 'alpha', 0.75)
                if (o != 'QSO'): 
                    muv = np.mean(dv[ok_no_cata])
                    sigmav = np.std(dv[ok_no_cata])
                    gauss = pylab.normpdf(bins, muv, sigmav)
                    l = pylab.plot(bins, gauss, 'k--', linewidth=1.5, label="mu=%2.0f, sig=%2.0f"%(muv,sigmav))
                    pylab.legend()
                pylab.xlabel("Delta v = c(zb-zt)/(1+zt) [km/s] (ZWARN=0 without catastrophic failures)")
                pylab.ylabel("Num. of %s targets per bin"%o)

            #- pull distribution

                pylab.figure()
                n, bins, patches = pylab.hist(pull, 30, normed=1, histtype='stepfilled')
                pylab.setp(patches, 'facecolor', 'c', 'alpha', 0.75)
                #            mu_pull = np.mean(pull[ok_no_cata])
                #            sigma_pull = np.std(pull[ok_no_cata])
                gauss = pylab.normpdf(bins, mu_pull, sigma_pull)
                l = pylab.plot(bins, gauss, 'k--', linewidth=1.5, label="mu=%2.3f, sig=%2.3f"%(mu_pull,sigma_pull))
                pylab.legend()
                mu=0.
                sig=1.
                gauss1 = pylab.normpdf(bins, mu, sig)
                l1 = pylab.plot(bins, gauss1, 'k--', linewidth=1.5, color='r', label="mu=0., sig=1.")
                pylab.legend()
                pylab.xlabel("Pull = (zb-<zt>)/zerr (ZWARN=0 without catastrophic failures)")
                pylab.ylabel("Num. of %s targets per bin"%o)

            #- other plots

                pylab.figure()
                nx = 1
                if len(cata) !=0:
                    ny = 3
                else:
                    ny = 2
                ai = 1

            # catastrophic failures in green
            # zwarn != 0 in red
            # zw =0 no catastrophic in blue

                a=pylab.subplot(ny,nx,ai); ai +=1
                a.errorbar(mean_ston[ok],dz[ok],errz[ok],fmt="bo")
                a.errorbar(mean_ston[cata],dz[cata],errz[cata],fmt="go")
                a.set_xlabel("%s <S/N>"%o)
                a.set_ylabel("(zb-zt)/(1+zt) (ZWARN=0)")

                a=pylab.subplot(ny,nx,ai); ai +=1
                a.errorbar(mean_ston[ok],dz[ok],errz[ok],fmt="bo")
                a.errorbar(mean_ston[not_ok],dz[not_ok],errz[not_ok],fmt="ro")
                a.errorbar(mean_ston[cata],dz[cata],errz[cata],fmt="go")
                a.set_xlabel("%s <S/N> "%o)
                a.set_ylabel("(zb-zt)/(1+zt) (all ZWARN)")

                if len(cata) !=0:
                    a=pylab.subplot(ny,nx,ai); ai +=1
                    a.errorbar(mean_ston[ok_no_cata],dz[ok_no_cata],errz[ok_no_cata],fmt="bo")
                    a.set_xlabel("%s <S/N> "%o)
                    a.set_ylabel("(zb-zt)/(1+zt) (ZWARN=0, no cata. fail.)")


                if (o == 'ELG'):
                    pylab.figure()
                    nx=1
                    if len(cata) !=0:
                        ny=3
                    else:
                        ny=2
                    ai=1

                    a=pylab.subplot(ny,nx,ai); ai +=1
                    a.errorbar(zb_zt['OIIFLUX'][ok],dz[ok],errz[ok],fmt="bo")
                    a.errorbar(zb_zt['OIIFLUX'][cata],dz[cata],errz[cata],fmt="ro")
                    a.set_xlabel("%s True [OII] flux"%o)
                    a.set_ylabel("(zb-zt)/(1+zt) (all ZWARN)")

                    a=pylab.subplot(ny,nx,ai); ai +=1
                    a.errorbar(zb_zt['OIIFLUX'][ok],dz[ok],errz[ok],fmt="bo")
                    a.errorbar(zb_zt['OIIFLUX'][not_ok],dz[not_ok],errz[not_ok],fmt="ro")
                    a.errorbar(zb_zt['OIIFLUX'][cata],dz[cata],errz[cata],fmt="go")
                    a.set_xlabel("%s True [OII] flux"%o)
                    a.set_ylabel("(zb-zt)/(1+zt) (ZWARN=0)")
                    
                    if len(cata) != 0:
                        a=pylab.subplot(ny,nx,ai); ai +=1
                        a.errorbar(zb_zt['OIIFLUX'][ok_no_cata],dz[ok_no_cata],errz[ok_no_cata],fmt="bo")
                        a.set_xlabel("%s True [OII] flux"%o)
                        a.set_ylabel("(zb-zt)/(1+zt) (ZWARN=0, no cata. fail.)")

                pylab.figure()
                nx=2
                if len(cata) !=0:
                    ny=3
                else:
                    ny=2
                ai=1
    
                a=pylab.subplot(ny,nx,ai); ai +=1
                a.errorbar(tz[ok],dz[ok],errz[ok],fmt="o",c="b")
                a.errorbar(tz[not_ok],dz[not_ok],errz[not_ok],fmt="o",c="r")
                a.errorbar(tz[cata],dz[cata],errz[cata],fmt="o",c="g")
                a.set_xlabel("%s zt (all ZWARN)"%o)
                a.set_ylabel("(zb-zt)/(1+zt) (all ZWARN)")
            
                a=pylab.subplot(ny,nx,ai); ai +=1
                a.errorbar(bz[ok],dz[ok],errz[ok],fmt="o",c="b")
                a.errorbar(bz[not_ok],dz[not_ok],errz[not_ok],fmt="o",c="r")
                a.errorbar(bz[cata],dz[cata],errz[cata],fmt="o",c="g")
                a.set_xlabel("%s zb (all ZWARN)"%o)
                a.set_ylabel("(zb-zt)/(1+zt)")
            
                a=pylab.subplot(ny,nx,ai); ai +=1
                a.errorbar(tz[ok],dz[ok],errz[ok],fmt="o",c="b")
                a.errorbar(tz[cata],dz[cata],errz[cata],fmt="o",c="g")
                a.set_xlabel("%s zt (ZWARN=0)"%o)
                a.set_ylabel("(zb-zt)/(1+zt)")

                a=pylab.subplot(ny,nx,ai); ai +=1
                a.errorbar(bz[ok],dz[ok],errz[ok],fmt="o",c="b")
                a.errorbar(bz[cata],dz[cata],errz[cata],fmt="o",c="g")
                a.set_xlabel("%s zb (ZWARN=0)"%o)
                a.set_ylabel("(zb-zt)/(1+zt)")

                if len(cata) !=0:
                    a=pylab.subplot(ny,nx,ai); ai +=1
                    a.errorbar(tz[ok_no_cata],dz[ok_no_cata],errz[ok_no_cata],fmt="o",c="b")
                    a.set_xlabel("%s zt (ZWARN=0, no cata. fail.)"%o)
                    a.set_ylabel("(zb-zt)/(1+zt)")
                    
                    a=pylab.subplot(ny,nx,ai); ai +=1
                    a.errorbar(bz[ok_no_cata],dz[ok_no_cata],errz[ok_no_cata],fmt="o",c="b")
                    a.set_xlabel("%s zb (All ZWARN, no cata. fail.)"%o)
                    a.set_ylabel("(zb-zt)/(1+zt)")


                pylab.show()                
Ejemplo n.º 55
0
def main(args):
    """Update the per-band coadds and the global coadd for one brick.

    For every band in ``args.bands`` the band's brick file is read, all
    exposures of each target are coadded, and the result is written to the
    band's coadd file.  The per-band coadds of each target are then combined
    on ``desispec.coaddition.global_wavelength_grid`` and written to the
    brick's ``coadd_all`` file.

    Args:
        args: object providing the attributes read below (presumably an
            argparse namespace -- confirm against the caller):
            ``verbose`` (enables DEBUG logging), ``brick`` (brick name),
            ``specprod`` (passed as ``specprod_dir`` to ``findfile``),
            ``bands`` (iterable of band names) and ``target`` (collection of
            target IDs to process; an empty collection means all targets).

    Returns:
        -1 if ``args.brick`` is None, otherwise None.

    Raises:
        AssertionError: if a row of a brick's HDU4 table has an ``INDEX``
            value different from its row position.
    """
    if args.verbose:
        log = get_logger(DEBUG)
    else:
        log = get_logger()

    if args.brick is None:
        log.critical('Missing required brick argument.')
        return -1

    # Open the combined coadd file for this brick, for updating.
    coadd_all_path = desispec.io.meta.findfile('coadd_all',
                                               brickname=args.brick,
                                               specprod_dir=args.specprod)
    coadd_all_file = desispec.io.brick.CoAddedBrick(coadd_all_path,
                                                    mode='update')

    # Per-target dictionaries of per-band coadded spectra, keyed by target ID.
    coadded_spectra = {}

    # Each target gets one coadd index, shared by all bands.
    next_coadd_index = 0
    target_index = {}

    # The HDU4 table for the global coadd will go here.
    coadd_all_info = None

    # Loop over bands for this brick.
    for band in args.bands:
        # Open this band's brick file for reading.
        brick_path = desispec.io.meta.findfile('brick',
                                               brickname=args.brick,
                                               band=band,
                                               specprod_dir=args.specprod)
        if not os.path.exists(brick_path):
            log.info(
                'Skipping non-existent brick file {0}.'.format(brick_path))
            continue
        brick_file = desispec.io.brick.Brick(brick_path, mode='readonly')
        flux_in, ivar_in, wlen, resolution_in = (brick_file.hdu_list[0].data,
                                                 brick_file.hdu_list[1].data,
                                                 brick_file.hdu_list[2].data,
                                                 brick_file.hdu_list[3].data)
        log.debug('Processing %s with %d exposures of %d targets...' %
                  (brick_path, brick_file.get_num_spectra(),
                   brick_file.get_num_targets()))
        if resolution_in.shape[1] != desispec.resolution.num_diagonals:
            log.error(
                'resolution has unexpected shape (ndiag=%d != %d). Skipping this file.'
                % (resolution_in.shape[1], desispec.resolution.num_diagonals))
            brick_file.close()
            continue
        # Open this band's coadd file for updating.
        coadd_path = desispec.io.meta.findfile('coadd',
                                               brickname=args.brick,
                                               band=band,
                                               specprod_dir=args.specprod)
        coadd_file = desispec.io.brick.CoAddedBrick(coadd_path, mode='update')
        # Copy the input fibermap info for each exposure into memory.
        coadd_info = np.copy(brick_file.hdu_list[4].data)
        # Also copy the first band's info to initialize the global coadd info,
        # but remember that this band might not have all targets so we could
        # see new targets in other bands.
        if coadd_all_info is None:
            coadd_all_info = np.copy(brick_file.hdu_list[4].data)

        # Loop over objects in the input brick file.
        for index, info in enumerate(brick_file.hdu_list[4].data):
            assert index == info['INDEX'], 'Index mismatch: %d != %d' % (
                index, info['INDEX'])
            target_id = info['TARGETID']
            # Are we only processing specified targets?  Check this before
            # building the spectrum so skipped targets cost nothing.
            if len(args.target) > 0 and target_id not in args.target:
                continue
            resolution_matrix = desispec.resolution.Resolution(
                resolution_in[index])
            spectrum = desispec.coaddition.Spectrum(wlen, flux_in[index],
                                                    ivar_in[index],
                                                    resolution_matrix)
            # Have we seen this target before?
            if target_id not in coadded_spectra:
                coadded_spectra[target_id] = {}
                target_index[target_id] = next_coadd_index
                next_coadd_index += 1
            # Save the coadd index to our output table.
            coadd_info['INDEX'][index] = target_index[target_id]
            # Initialize the coadd for this band and target if necessary.
            if band not in coadded_spectra[target_id]:
                coadded_spectra[target_id][
                    band] = desispec.coaddition.Spectrum(wlen)
            # Do the coaddition.
            coadded_spectra[target_id][band] += spectrum

            # Is this exposure of this target already in our global coadd table?
            exposure = info['EXPID']
            seen = (coadd_all_info['EXPID']
                    == exposure) & (coadd_all_info['TARGETID'] == target_id)
            if not np.any(seen):
                log.info(
                    'Adding exposure %d of target %d to global coadd with partial band coverage.'
                    % (exposure, target_id))
                # ndarrays have no append(); concatenate a one-row slice
                # (which keeps the structured dtype) and rebind the table.
                coadd_all_info = np.concatenate(
                    (coadd_all_info, coadd_info[index:index + 1]))
            else:
                # Update the matching global rows.  The row position in this
                # band's table is not necessarily the row position in the
                # global table, so address the rows through the mask.
                coadd_all_info['INDEX'][seen] = target_index[target_id]

        # Allocate arrays for the coadded results for this band. Since we
        # always use the same index for the same target in each band, there
        # might be some unused entries in these arrays if some bands are
        # missing for some targets.
        num_targets = 1 + np.max(coadd_info['INDEX'])
        nbins = len(wlen)
        flux_out = np.zeros((num_targets, nbins))
        ivar_out = np.zeros_like(flux_out)
        resolution_out = np.zeros(
            (num_targets, desispec.resolution.num_diagonals, nbins))

        # Save the coadded spectra for this band.
        for target_id in coadded_spectra:
            if band not in coadded_spectra[target_id]:
                continue
            exposures = (coadd_info['TARGETID'] == target_id)
            index = target_index[target_id]
            log.debug(
                'Saving coadd of %d exposures for target ID %d to index %d.' %
                (np.count_nonzero(exposures), target_id, index))
            coadd = coadded_spectra[target_id][band]
            coadd.finalize()
            flux_out[index] = coadd.flux
            ivar_out[index] = coadd.ivar
            resolution_out[index] = coadd.resolution.to_fits_array()

        # Save the coadds for this band.
        coadd_file.add_objects(flux_out, ivar_out, wlen, resolution_out)
        coadd_file.hdu_list[4].data = coadd_info

        # Close files for this band.
        coadd_file.close()
        brick_file.close()

    # Allocate space for the global coadded results.  Every index below
    # next_coadd_index belongs to a target in coadded_spectra, so each row of
    # these uninitialized arrays is overwritten in the loop that follows.
    num_targets = next_coadd_index
    nbins = len(desispec.coaddition.global_wavelength_grid)
    flux_all = np.empty((num_targets, nbins))
    ivar_all = np.empty_like(flux_all)
    resolution_all = np.empty(
        (num_targets, desispec.resolution.num_diagonals, nbins))

    # Coadd the bands for each target ID.
    all_bands = ','.join(sorted(args.bands))
    for target_id in coadded_spectra:
        index = target_index[target_id]
        bands = ','.join(sorted(coadded_spectra[target_id].keys()))
        log.debug('Combining %s bands for target %d at index %d.' %
                  (bands, target_id, index))
        if bands != all_bands:
            log.warning('WARNING: target %d has partial band coverage: %s' %
                        (target_id, bands))
        coadd_all = desispec.coaddition.Spectrum(
            desispec.coaddition.global_wavelength_grid)
        # values() works on both Python 2 and 3 (itervalues() is 2-only).
        for coadd_band in coadded_spectra[target_id].values():
            coadd_all += coadd_band
        coadd_all.finalize()
        flux_all[index] = coadd_all.flux
        ivar_all[index] = coadd_all.ivar
        resolution_all[index] = coadd_all.resolution.to_fits_array()

    # Save the global coadds.
    coadd_all_file.add_objects(flux_all, ivar_all,
                               desispec.coaddition.global_wavelength_grid,
                               resolution_all)
    coadd_all_file.hdu_list[4].data = coadd_all_info

    # Close the combined coadd file.
    coadd_all_file.close()
Ejemplo n.º 56
0
"""
desisim.templates
=================

Functions to simulate spectral templates for DESI.
"""

from __future__ import division, print_function

import os
import sys
import numpy as np

from desispec.log import get_logger
log = get_logger()

class TargetCuts():
    """Select targets from flux cuts.  This is a placeholder class that will be
       refactored into desitarget.  Hence, the documentation here is
       intentionally sparse.

    """
    def __init__(self):
        pass
        
    def BGS(self,rflux=None):
        BGS = rflux > 10**((22.5-19.35)/2.5)
        return BGS

    def ELG(self,gflux=None, rflux=None, zflux=None):
        ELG  = rflux > 10**((22.5-23.4)/2.5)