Example #1
def spline_fit(output_wave,input_wave,input_flux,required_resolution,input_ivar=None,order=3,max_resolution=None):
    """Performs spline fit of input_flux vs. input_wave and resamples at output_wave

    Args:
        output_wave : 1D array of output wavelength samples
        input_wave : 1D array of input wavelengths
        input_flux : 1D array of input flux density
        required_resolution (float) : resolution for spline knot placement (same unit as wavelength)

    Options:
        input_ivar : 1D array of weights for input_flux
        order (int) : spline order
        max_resolution (float) : if not None and the first fit fails, retry once with this resolution

    Returns:
        output_flux : 1D array of flux sampled at output_wave
    """
    if input_ivar is not None :
        selection=np.where(input_ivar>0)[0]
        if selection.size < 2 :
            log=get_logger()
            msg = "cannot do spline fit because only {0:d} values with ivar>0".format(selection.size)
            log.error(msg)
            raise ValueError(msg)
        w1=input_wave[selection[0]]
        w2=input_wave[selection[-1]]
    else :
        w1=input_wave[0]
        w2=input_wave[-1]

    res=required_resolution
    n=int((w2-w1)/res)
    res=(w2-w1)/(n+1)
    knots=w1+res*(0.5+np.arange(n))

    ## check that nodes are close to pixels
    dknots = abs(knots[:,None]-input_wave)
    mins = np.amin(dknots,axis=1)
    w=mins<res
    knots = knots[w]
    try :
        tck=scipy.interpolate.splrep(input_wave,input_flux,w=input_ivar,k=order,task=-1,t=knots)
        output_flux = scipy.interpolate.splev(output_wave,tck)
    except ValueError as err :
        log=get_logger()
        if max_resolution is not None  and required_resolution < max_resolution :
            log.warning("spline fit failed with resolution={}, retrying with {}".format(required_resolution,max_resolution))
            return spline_fit(output_wave,input_wave,input_flux,max_resolution,input_ivar=input_ivar,order=order,max_resolution=None)
        else :
            log.error("spline fit failed")
            raise ValueError("spline fit failed") from err
    return output_flux
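# --- Usage sketch (illustrative, not part of the original source) ---
# Fit a cubic spline through a noisy Gaussian profile and resample it on a
# finer grid; the knot spacing of 50 and the fallback of 200 are assumptions.
import numpy as np
xin = np.linspace(3600., 9800., 500)
yin = np.exp(-0.5 * ((xin - 6500.) / 400.) ** 2) + 0.01 * np.random.randn(xin.size)
ivar = np.full(xin.size, 1. / 0.01 ** 2)   # uniform inverse variance
xout = np.linspace(3700., 9700., 2000)
yout = spline_fit(xout, xin, yin, required_resolution=50.,
                  input_ivar=ivar, max_resolution=200.)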
Example #2
def sim(night, nspec=5, clobber=False):
    """
    Simulate data as part of the integration test.

    Args:
        night (str): YEARMMDD
        nspec (int, optional): number of spectra to include
        clobber (bool, optional): rerun steps even if outputs already exist

    Raises:
        RuntimeError if any script fails
    """
    log = logging.get_logger()

    # Create input fibermaps, spectra, and pixel-level raw data

    for expid, program in zip([0,1,2], ['flat', 'arc', 'dark']):
        cmd = "newexp-random --program {program} --nspec {nspec} --night {night} --expid {expid}".format(
            expid=expid, program=program, nspec=nspec, night=night)
        fibermap = io.findfile('fibermap', night, expid)
        simspec = '{}/simspec-{:08d}.fits'.format(os.path.dirname(fibermap), expid)
        inputs = []
        outputs = [fibermap, simspec]
        if runcmd(cmd, inputs=inputs, outputs=outputs, clobber=clobber) != 0:
            raise RuntimeError('newexp-random failed for {} exposure {}'.format(program, expid))

        cmd = "pixsim --nspec {nspec} --night {night} --expid {expid}".format(expid=expid, nspec=nspec, night=night)
        inputs = [fibermap, simspec]
        outputs = list()
        outputs.append(fibermap.replace('fibermap-', 'simpix-'))
        outputs.append(io.findfile('raw', night, expid))
        if runcmd(cmd, inputs=inputs, outputs=outputs, clobber=clobber) != 0:
            raise RuntimeError('pixsim failed for {} exposure {}'.format(program, expid))

    return
Example #3
    def init_fluxcalib(self, re_init=False):
        """ Initialize parameters for FLUXCALIB QA
        Args:
            re_init: bool, optional
              Re-initialize the parameter dict

        Returns:

        """
        log=get_logger()
        assert self.flavor == 'science'

        # Standard FLUXCALIB input parameters
        flux_dict = dict(ZP_WAVE=0.,        # Wavelength for ZP evaluation (camera dependent)
                         MAX_ZP_OFF=0.2,    # Max offset in ZP for individual star
                         )

        if self.camera[0] == 'b':
            flux_dict['ZP_WAVE'] = 4800.  # Ang
        elif self.camera[0] == 'r':
            flux_dict['ZP_WAVE'] = 6500.  # Ang
        elif self.camera[0] == 'z':
            flux_dict['ZP_WAVE'] = 8250.  # Ang
        else:
            log.error("Not ready for camera {}!".format(self.camera))

        # Init
        self.init_qatype('FLUXCALIB', flux_dict, re_init=re_init)
Example #4
def check_env():
    """Check required environment variables.

    Raises:
        RuntimeError if any required environment variable is missing
    """
    log = get_logger()
    #- template locations
    missing_env = False
    if 'DESI_BASIS_TEMPLATES' not in os.environ:
        log.warning('missing $DESI_BASIS_TEMPLATES needed for simulating spectra')
        missing_env = True

    elif not os.path.isdir(os.getenv('DESI_BASIS_TEMPLATES')):
        log.warning('missing $DESI_BASIS_TEMPLATES directory')
        log.warning('e.g. see NERSC:/project/projectdirs/desi/spectro/templates/basis_templates/v2.2')
        missing_env = True

    for name in (
        'DESI_SPECTRO_SIM', 'DESI_SPECTRO_REDUX', 'PIXPROD', 'SPECPROD'):
        if name not in os.environ:
            log.warning("missing ${0}".format(name))
            missing_env = True

    if missing_env:
        log.warning("Why are these needed?")
        log.warning("    Simulations written to $DESI_SPECTRO_SIM/$PIXPROD/")
        log.warning("    Raw data read from $DESI_SPECTRO_DATA/")
        log.warning("    Spectro pipeline output written to $DESI_SPECTRO_REDUX/$SPECPROD/")
        log.warning("    Templates are read from $DESI_BASIS_TEMPLATES")
        log.critical("missing env vars; exiting without running pipeline")
        raise RuntimeError("missing env vars; exiting without running pipeline")

    #- Override $DESI_SPECTRO_DATA to match $DESI_SPECTRO_SIM/$PIXPROD
    os.environ['DESI_SPECTRO_DATA'] = os.path.join(os.getenv('DESI_SPECTRO_SIM'), os.getenv('PIXPROD'))
Example #5
    def slurp_nights(self, make_frameqa=False, remove=True, write_nights=False, **kwargs):
        """ Slurp all the individual QA files, night by night
        Loops on nights, generating QA_Night objects along the way

        Args:
            make_frameqa: bool, optional
              Regenerate the individual QA files (at the frame level first)
            remove: bool, optional
              Remove the individual QA files?
            write_nights: bool, optional
              Write each QA_Night's QA exposures to disk?

        Returns:

        """
        log = get_logger()
        # Remake?
        if make_frameqa:
            self.make_frameqa(**kwargs)
        # Reset
        log.info("Resetting QA_Night objects")
        self.qa_nights = []
        # Loop on nights
        for night in self.mexp_dict.keys():
            qaNight = QA_Night(night)
            qaNight.slurp(remove=remove)
            # Save nights
            self.qa_nights.append(qaNight)
            # Write?
            if write_nights:
                qaNight.write_qa_exposures()
Example #6
def search_for_framefile(frame_file):
    """ Search for an input frame_file in the desispec redux hierarchy
    Args:
        frame_file:  str

    Returns:
        mfile: str,  full path to frame_file if found else raise error

    """
    log=get_logger()
    # Parse frame file
    path, ifile = os.path.split(frame_file)
    splits = ifile.split('-')
    root = splits[0]
    camera = splits[1]
    fexposure = int(splits[2].split('.')[0])

    # Loop on nights
    nights = get_nights()
    for night in nights:
        for exposure in get_exposures(night):
            if exposure == fexposure:
                mfile = findfile(root, camera=camera, night=night, expid=exposure)
                if os.path.isfile(mfile):
                    return mfile
                else:
                    log.error("Expected file {:s} not found".format(mfile))
    # nothing matched: raise an error, as promised in the docstring
    raise IOError("could not find frame file {:s}".format(frame_file))
Example #7
def compatible(head1,head2) :
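    """Return True if the two PSF headers agree on all structural keys, else False."""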
    log=get_logger()
    for k in ["PSFTYPE","NPIX_X","NPIX_Y","HSIZEX","HSIZEY","FIBERMIN","FIBERMAX","NPARAMS","LEGDEG","GHDEGX","GHDEGY"] :
        if (head1[k] != head2[k]) :
            log.warning("different %s : %s , %s"%(k,head1[k],head2[k]))
            return False
    return True
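# --- Usage sketch (illustrative, not part of the original source) ---
# Plain dicts stand in for the astropy FITS headers that real callers pass.
keys = ["PSFTYPE", "NPIX_X", "NPIX_Y", "HSIZEX", "HSIZEY", "FIBERMIN",
        "FIBERMAX", "NPARAMS", "LEGDEG", "GHDEGX", "GHDEGY"]
head1 = {k: 1 for k in keys}
head2 = dict(head1)
assert compatible(head1, head2)
head2["NPIX_X"] = 2
assert not compatible(head1, head2)   # logs a warning and returns False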
Example #8
    def slurp(self, make_frameqa=False, remove=True, **kwargs):
        """ Slurp all the individual QA files to generate
        a list of QA_Exposure objects

        Args:
            make_frameqa: bool, optional
              Regenerate the individual QA files (at the frame level first)
            remove: bool, optional
              Remove the individual QA files?

        Returns:

        """
        from desispec.qa import QA_Exposure
        log = get_logger()
        # Remake?
        if make_frameqa:
            self.make_frameqa(**kwargs)
        # Loop on nights
        # Reset
        log.info("Resetting QA_Exposure objects")
        self.qa_exps = []
        # Loop
        for night in self.mexp_dict.keys():
            # Loop on exposures
            for exposure in self.mexp_dict[night].keys():
                frames_dict = self.mexp_dict[night][exposure]
                if len(frames_dict) == 0:
                    continue
                # Load any frame (for the type)
                qa_exp = QA_Exposure(exposure, night,
                                     specprod_dir=self.specprod_dir, remove=remove)
                # Append
                self.qa_exps.append(qa_exp)
Example #9
def sim(night, nspec=25, clobber=False):
    """
    Simulate data as part of the integration test.

    Args:
        night (str): YEARMMDD
        nspec (int, optional): number of spectra to include
        clobber (bool, optional): rerun steps even if outputs already exist
        
    Raises:
        RuntimeError if any script fails
    """
    log = logging.get_logger()
    output_dir = os.path.join(os.path.expandvars('$DESI_SPECTRO_REDUX'), 'calib2d')

    # Create input fibermaps, spectra, and quickgen data

    for expid, program in zip([0,1,2], ['flat', 'arc', 'dark']):
        cmd = "newexp-random --program {program} --nspec {nspec} --night {night} --expid {expid}".format(expid=expid, program=program, nspec=nspec, night=night)

        simspec = desisim.io.findfile('simspec', night, expid)
        fibermap = '{}/fibermap-{:08d}.fits'.format(os.path.dirname(simspec),expid) 
        if runcmd(cmd, clobber=clobber) != 0:
            raise RuntimeError('newexp-random failed for {} exposure {}'.format(program, expid))

        cmd = "quickgen --simspec {} --fibermap {}".format(simspec,fibermap)
        if runcmd(cmd, clobber=clobber) != 0:
            raise RuntimeError('quickgen failed for {} exposure {}'.format(program, expid))

    return
Example #10
    def _run_defaults(self):
        """See BaseTask.run_defaults.
        """
        import glob

        log = get_logger()

        opts = {}
        starmodels = None
        if "DESI_BASIS_TEMPLATES" in os.environ:
            filenames = sorted(glob.glob(os.environ["DESI_BASIS_TEMPLATES"]+"/stdstar_templates_*.fits"))
            if len(filenames) > 0 :
                starmodels = filenames[-1]
            else:
                filenames = sorted(glob.glob(os.environ["DESI_BASIS_TEMPLATES"]+"/star_templates_*.fits"))
                if len(filenames) > 0 :
                    log.warning('Unable to find stdstar templates in {}; using star templates instead'.format(
                        os.getenv('DESI_BASIS_TEMPLATES')))
                    starmodels = filenames[-1]
                else:
                    msg = 'Unable to find stdstar or star templates in {}'.format(
                        os.getenv('DESI_BASIS_TEMPLATES'))
                    log.error(msg)
                    raise RuntimeError(msg)
        else:
            log.error("DESI_BASIS_TEMPLATES not set!")
            raise RuntimeError("could not find the stellar templates")

        opts["starmodels"] = starmodels

        opts["delta-color"] = 0.2
        opts["color"] = "G-R"

        return opts
Example #11
def main(args):

    log = get_logger()

    #- Generate obsconditions with args.program, then override as needed
    args.program = args.program.upper()
    if args.program in ['ARC', 'FLAT']:
        obsconditions = None
    else:
        obsconditions = desisim.simexp.reference_conditions[args.program]
        # override any reference condition that was set on the command line
        for key in ('airmass', 'seeing', 'exptime', 'moonfrac', 'moonalt', 'moonsep'):
            value = getattr(args, key)
            if value is not None:
                obsconditions[key.upper()] = value

    sim, fibermap, meta, obs, objmeta = desisim.obs.new_exposure(args.program,
        nspec=args.nspec, night=args.night, expid=args.expid, 
        tileid=args.tileid, nproc=args.nproc, seed=args.seed, 
        obsconditions=obsconditions, outdir=args.outdir)
Example #12
def _overscan(pix, nsigma=5, niter=3):
    '''
    returns overscan, readnoise from overscan image pixels

    Args:
        pix (ndarray) : overscan pixels from CCD image

    Optional:
        nsigma (float) : number of standard deviations for sigma clipping
        niter (int) : number of iterative refits
    '''
    log=get_logger()
    #- normalized median absolute deviation as robust version of RMS
    #- see https://en.wikipedia.org/wiki/Median_absolute_deviation
    overscan = np.median(pix)
    absdiff = np.abs(pix - overscan)
    readnoise = 1.4826*np.median(absdiff)

    #- input pixels are integers, so iteratively refit
    for i in range(niter):
        absdiff = np.abs(pix - overscan)
        good = absdiff < nsigma*readnoise
        if np.sum(good)<5 :
            log.error("error in sigma clipping for overscan measurement, return result without clipping")
            overscan = np.median(pix)
            absdiff = np.abs(pix - overscan)
            readnoise = 1.4826*np.median(absdiff)
            return overscan,readnoise
        overscan = np.mean(pix[good])
        readnoise = np.std(pix[good])

    #- correct for bias from sigma clipping
    readnoise /= _clipped_std_bias(nsigma)

    return overscan, readnoise
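# --- Usage sketch (illustrative, not part of the original source) ---
# Recover the bias level and read noise from synthetic overscan pixels with a
# few injected outliers; assumes numpy is imported as np.
import numpy as np
rng = np.random.RandomState(0)
pix = rng.normal(loc=1000., scale=3., size=2000)
pix[:5] += 500.   # outliers that the sigma clipping should reject
overscan, readnoise = _overscan(pix, nsigma=5, niter=3)
# expect overscan near 1000 and readnoise near 3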
Example #13
def obj_s2n_wave(s2n_dict, wv_bins, flux_bins, otype, outfile=None, ax=None):
    """Generate QA of S/N for a given object type
    """
    logs = get_logger()
    nwv = wv_bins.size
    nfx = flux_bins.size
    s2n_sum = np.zeros((nwv-1,nfx-1))
    s2n_N = np.zeros((nwv-1,nfx-1)).astype(int)
    # Loop on exposures+wedges  (can do just once if these are identical for each)
    for jj, wave in enumerate(s2n_dict['waves']):
        w_i = np.digitize(wave, wv_bins) - 1
        m_i = np.digitize(s2n_dict['fluxes'][jj], flux_bins) - 1
        mmm = []
        for ll in range(nfx-1): # Only need to do once
            mmm.append(m_i == ll)
        #
        for kk in range(nwv-1):
            all_s2n = s2n_dict['s2n'][jj][:,w_i==kk]
            for ll in range(nfx-1):
                if np.any(mmm[ll]):
                    s2n_sum[kk, ll] += np.sum(all_s2n[mmm[ll],:])
                    s2n_N[kk, ll] += np.sum(mmm[ll]) * all_s2n.shape[1]

    sty_otype = get_sty_otype()

    # Plot
    if ax is None:
        fig = plt.figure(figsize=(6, 6.0))
        ax = plt.gca()
    else:
        fig = ax.get_figure()
    # Title
    fig.suptitle('{:s}: Summary'.format(sty_otype[otype]['lbl']),
        fontsize='large')

    # Plot em up
    wv_cen = (wv_bins + np.roll(wv_bins,-1))/2.
    lstys = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
    mxy = 1e-9
    for ss in range(nfx-1):
        if np.sum(s2n_N[:,ss]) == 0:
            continue
        lbl = 'MAG = [{:0.1f},{:0.1f}]'.format(flux_bins[ss], flux_bins[ss+1])
        ax.plot(wv_cen[:-1], s2n_sum[:,ss]/s2n_N[:,ss], linestyle=lstys[ss],
                label=lbl, color=sty_otype[otype]['color'])
        mxy = max(mxy, np.max(s2n_sum[:,ss]/s2n_N[:,ss]))

    ax.set_xlabel('Wavelength (Ang)')
    #ax.set_xlim(-ylim, ylim)
    ax.set_ylabel('Mean S/N per Ang in bins of 20A')
    ax.set_yscale("log", nonposy='clip')
    ax.set_ylim(0.1, mxy*1.1)

    legend = plt.legend(loc='upper left', scatterpoints=1, borderpad=0.3,
                      handletextpad=0.3, fontsize='medium', numpoints=1)

    # Finish
    plt.tight_layout(pad=0.2,h_pad=0.2,w_pad=0.3)
    plt.subplots_adjust(top=0.92)
    if outfile is not None:
        plt.savefig(outfile, dpi=600)
        print("Wrote: {:s}".format(outfile))
Example #14
def move_file(filename, dst):
    """Move delivered file from the DTS spool to the final raw data area.

    This function will ensure that the destination directory exists.

    Parameters
    ----------
    filename : :class:`str`
        The name, including full path, of the file to move.
    dst : :class:`str`
        The destination *directory*.

    Returns
    -------
    :class:`str`
        The value returned by :func:`shutil.move`.
    """
    from os import mkdir
    from os.path import exists, isdir
    from shutil import move
    from desiutil.log import get_logger
    log = get_logger()
    if not exists(dst):
        log.info("mkdir('{0}', 0o2770)".format(dst))
        mkdir(dst, 0o2770)
    log.info("move('{0}', '{1}')".format(filename, dst))
    return move(filename, dst)
Example #15
def main():
    """Entry point for :command:`desi_dts_delivery`.

    Returns
    -------
    :class:`int`
        An integer suitable for passing to :func:`sys.exit`.
    """
    from os import environ
    from os.path import dirname, join
    from subprocess import Popen
    from desiutil.log import get_logger
    log = get_logger()
    options = parse_delivery()
    remote_command = ['desi_{0.nightStatus}_night {0.night}'.format(options)]
    if options.prefix is not None:
        remote_command = options.prefix + remote_command
    remote_command = ('(' +
                      '; '.join([c + ' &> /dev/null' for c in remote_command]) +
                      ' &)')
    command = ['ssh', '-n', '-q', options.nersc_host, remote_command]
    log.info("Received file {0.filename} with exposure number {0.exposure:d}.".format(options))
    dst = join(environ['DESI_SPECTRO_DATA'], options.night)
    log.info("Using {0} as raw data directory.".format(dst))
    move_file(options.filename, dst)
    exposure_arrived = check_exposure(dst, options.exposure)
    if options.nightStatus in ('start', 'end') or exposure_arrived:
        log.info("Calling: {0}.".format(' '.join(command)))
        proc = Popen(command)
    return 0
Example #16
def load_qa_frame(filename, frame=None, flavor=None):
    """ Load an existing QA_Frame or generate one, as needed

    Args:
        filename: str
        frame: Frame object, optional
        flavor: str, optional
            Type of QA_Frame

    Returns:
        qa_frame: QA_Frame object
    """
    from desispec.qa.qa_frame import QA_Frame
    log=get_logger()
    if os.path.isfile(filename): # Read from file, if it exists
        qaframe = read_qa_frame(filename)
        log.info("Loaded QA file {:s}".format(filename))
        # Check against frame, if provided
        if frame is not None:
            for key in ['camera','expid','night','flavor']:
                assert getattr(qaframe, key) == frame.meta[key.upper()]
    else:  # Init
        if frame is None:
            msg = "QA file {:s} does not exist; expecting frame input".format(filename)
            log.error(msg)
            raise IOError(msg)
        qaframe = QA_Frame(frame)
    # Set flavor?
    if flavor is not None:
        qaframe.flavor = flavor
    # Return
    return qaframe
Example #17
    def _insert(self, cursor, props):
        """See BaseTask.insert.
        """
        log = get_logger()

        name = self.name_join(props)
        colstr = '(name'
        valstr = "('{}'".format(name)

        #cmd='insert or replace into {} values ("{}"'.format(self._type, name)
        for k, ktype in zip(self._cols, self._coltypes):
            colstr += ', {}'.format(k)
            if k == "state":
                if k in props:
                    valstr += ', {}'.format(task_state_to_int[props["state"]])
                else:
                    valstr += ', {}'.format(task_state_to_int["waiting"])
            else:
                if ktype == "text":
                    valstr += ", '{}'".format(props[k])
                else:
                    valstr += ', {}'.format(props[k])
        colstr += ', submitted)'
        valstr += ', 0)'

        cmd = 'insert into {} {} values {}'.format(self._type, colstr, valstr)
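        # For illustration (schematic, not from the source), the assembled
        # statement looks something like:
        #   insert into <tasktype> (name, <col1>, ..., state, submitted)
        #   values ('<task name>', <val1>, ..., 0, 0)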
        log.debug(cmd)
        cursor.execute(cmd)
        return
Example #18
def _traceset_from_table(wavemin,wavemax,hdu,pname) :
    log=get_logger()
    head=hdu.header
    table=hdu.data
    
    extname=head["EXTNAME"]
    i=np.where(table["PARAM"]==pname)[0][0]

    if "WAVEMIN" in table.dtype.names :
        twavemin=table["WAVEMIN"][i]
        if wavemin is not None :
            if abs(twavemin-wavemin)>0.001 :
                mess="WAVEMIN not matching in hdu {} {}!={}".format(extname,twavemin,wavemin)
                log.error(mess)
                raise ValueError(mess)
        else :
            wavemin=twavemin
    
    if "WAVEMAX" in table.dtype.names :
        twavemax=table["WAVEMAX"][i]
        if wavemax is not None :
            if abs(twavemax-wavemax)>0.001 :
                mess="WAVEMAX not matching in hdu {} {}!={}".format(extname,twavemax,wavemax)
                log.error(mess)
                raise ValueError(mess)
        else :
            wavemax=twavemax
    
    log.info("read {} from hdu {}".format(pname,extname))
    return table["COEFF"][i],wavemin,wavemax 
Example #19
def parse(options=None):
    parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    #- Required
    parser.add_argument('--fiberassign', type=str, required=True,
                        help="input fiberassign directory or tile file")
    parser.add_argument('--mockdir', type=str, required=True,
                        help="directory with mock targets and truth")
    parser.add_argument('--obslist', type=str, required=True,
                        help="input surveysim obslist file")
    parser.add_argument('--expid', type=int, required=True, help="exposure ID")

    #- Optional
    parser.add_argument('--nside', help='healpixel organization scheme of the mock spectra', type=int, default=64)
    parser.add_argument('--outdir', type=str, help="output directory")
    parser.add_argument('--nspec', type=int, default=None, help="number of spectra to include")
    parser.add_argument('--clobber', action='store_true', help="overwrite any pre-existing output files")

    log = get_logger()
    if options is None:
        args = parser.parse_args()
        log.info(' '.join(sys.argv))
    else:
        args = parser.parse_args(options)
        log.info('newexp-mock '+' '.join(options))

    return args
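# --- Usage sketch (illustrative, not part of the original source) ---
# Parse a canned option list instead of sys.argv; the paths are placeholders.
args = parse(['--fiberassign', 'tiles/', '--mockdir', 'mocks/',
              '--obslist', 'obslist.fits', '--expid', '1'])
print(args.expid, args.nside, args.nspec)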
Example #20
    def run(self, grph, task, opts, comm=None):
        """
        Run the PSF combining.

        This is a serial call to libspecex to combine the PSF files.

        Args:
            grph (dict): pruned graph with this task and dependencies.
            task (str): the name of this task.
            opts (dict): options to use for this task.
            comm (mpi4py.MPI.Comm): optional MPI communicator.
        """
        if comm is not None:
            if comm.size > 1:
                raise RuntimeError("PSFCombine worker should only be called with one process")

        log = get_logger()

        node = grph[task]

        outfile = graph_path(task)
        infiles = []
        for input in node["in"]:
            infiles.append(graph_path(input))

        specex.mean_psf(infiles, outfile)

        return
Example #21
def write_xytraceset(outfile,xytraceset) :
    """
    Write a traceset FITS file and return the path to the file written.
    
    Args:
        outfile: full path to output file
        xytraceset:  desispec.xytraceset.XYTraceSet object
    
    Returns:
         full filepath of output file that was written    
    """

    log=get_logger()
    outfile = makepath(outfile, 'frame')
    hdus = fits.HDUList()
    x = fits.PrimaryHDU(xytraceset.x_vs_wave_traceset._coeff.astype('f4'))
    x.header['EXTNAME'] = "XTRACE"
    hdus.append(x)
    hdus.append( fits.ImageHDU(xytraceset.y_vs_wave_traceset._coeff.astype('f4'), name="YTRACE") )
    if xytraceset.xsig_vs_wave_traceset is not None:
        hdus.append( fits.ImageHDU(xytraceset.xsig_vs_wave_traceset._coeff.astype('f4'), name='XSIG') )
    if xytraceset.ysig_vs_wave_traceset is not None:
        hdus.append( fits.ImageHDU(xytraceset.ysig_vs_wave_traceset._coeff.astype('f4'), name='YSIG') )
    for hdu in ["XTRACE","YTRACE","XSIG","YSIG"] :
        if hdu in hdus :
            hdus[hdu].header["WAVEMIN"] = xytraceset.wavemin
            hdus[hdu].header["WAVEMAX"] = xytraceset.wavemax
            hdus[hdu].header["NPIX_Y"]  = xytraceset.npix_y
    hdus.writeto(outfile+'.tmp', overwrite=True, checksum=True)
    os.rename(outfile+'.tmp', outfile)
    log.info("wrote xytraceset to {}".format(outfile))
    return outfile
Example #22
    def __init__(self):
        self.log = get_logger()

        parser = argparse.ArgumentParser(
            description="DESI nightly processing",
            usage="""desi_night <command> [options]

Where supported commands are:
  update    Process an incoming exposure as much as possible.  Arc exposures
            will trigger PSF estimation.  If the nightly PSF exists, then a
            flat exposure will be extracted and a fiberflat will be created.
            If the nightly PSF exists, then a science exposure will be
            extracted.  If the nightly fiberflat exists, then a science
            exposure will be calibrated.
  arcs      All arcs are done, proceed with nightly PSF.
  flats     All flats are done, proceed with nightly fiberflat.
  redshifts Regroup spectra and process all updated redshifts.
""")
        parser.add_argument("command", help="Subcommand to run")
        # parse_args defaults to [1:] for args, but you need to
        # exclude the rest of the args too, or validation will fail
        args = parser.parse_args(sys.argv[1:2])
        if not hasattr(self, args.command):
            print("Unrecognized command")
            parser.print_help()
            sys.exit(errs["usage"])

        # use dispatch pattern to invoke method with same name
        getattr(self, args.command)()
Example #23
def default_options(extra=None):
    """
    Get the default options for all workers.

    Args:
        extra (dict): optional extra options to add to the
            default options for each worker class.

    Returns (dict):
        the default options dictionary, suitable for writing
        to the default options.yaml file.
    """
    if extra is None:
        extra = {}

    log = get_logger()

    allopts = {}

    for step in step_types:
        defwork = default_workers[step]
        allopts["{}_worker".format(step)] = defwork
        if defwork in extra:
            allopts["{}_worker_opts".format(step)] = extra[defwork]
        else:
            allopts["{}_worker_opts".format(step)] = {}
        worker = get_worker(step, None, {})
        allopts[step] = worker.default_options()

    return allopts
Example #24
    def _option_list(self, name, opts, db):

        # we need db access for spectra
        if db is None:
            log = get_logger()
            log.error("we need db access for spectra")
            raise RuntimeError("we need db access for spectra")

        from .base import task_classes, task_type
        # get pixel
        props = self.name_split(name)
        # get the list of exposures and spectrographs by selecting entries in the
        # healpix_frame table with state = 1, which means that there is a new
        # cframe intersecting the pixel
        entries = db.select_healpix_frame({"pixel": props["pixel"], "nside": props["nside"], "state": 1})
        # now select cframes with the same props
        cframes = []
        for entry in entries:
            for band in ["b", "r", "z"]:
                entry_and_band = entry.copy()
                entry_and_band["band"] = band
                # this will match cframes with the same expid and spectro
                taskname = task_classes["cframe"].name_join(entry_and_band)
                filename = task_classes["cframe"].paths(taskname)[0]
                cframes.append(filename)

        options = {}
        options["infiles"] = cframes
        options["outfile"] = self.paths(name)[0]
        options["healpix"] = props["pixel"]
        options["nside"] = props["nside"]

        return option_list(options)
Example #25
def insert_dlas(wave, zem, rstate=None, seed=None, fNHI=None, debug=False, **kwargs):
    """ Insert zero, one or more DLAs into a given spectrum towards a source
    with a given redshift
    Args:
        wave (ndarray):  wavelength array in Ang
        zem (float): quasar emission redshift
        rstate (numpy.random.RandomState, optional): for random numbers
        seed (int, optional): seed for the random state if rstate is None
        fNHI (spline): f_NHI object
        **kwargs: Passed to init_fNHI()

    Returns:
        dlas (list): List of DLA dict's with keys z,N
        dla_model (ndarray): normalized spectrum with DLAs inserted

    """
    from scipy import interpolate
    log = get_logger()
    # Init
    if rstate is None:
        rstate = np.random.RandomState(seed) # this is breaking the chain of randoms if seed is None
    if fNHI is None:
        fNHI = init_fNHI(**kwargs)

    # Allowed redshift placement
    ## Cut on zem and 910A rest-frame
    zlya = wave/1215.67 - 1
    dz = np.roll(zlya,-1)-zlya
    dz[-1] = dz[-2]
    gdz = (zlya < zem) & (wave > 910.*(1+zem))
    # l(z) -- Uses DLA for SLLS too which is fine
    lz = calc_lz(zlya[gdz])
    cum_lz = np.cumsum(lz*dz[gdz])
    tot_lz = cum_lz[-1]
    if len(cum_lz) < 2:
        log.warning('cum_lz in insert_dlas has only {} element(s); skipping DLA insertion'.format(len(cum_lz)))
        dlas, dla_model = [], []
        return dlas, dla_model

    fzdla = interpolate.interp1d(cum_lz/tot_lz, zlya[gdz],
                                 bounds_error=False,fill_value=np.min(zlya[gdz]))#
    # n DLA
    nDLA = rstate.poisson(tot_lz, 1)

    # Generate DLAs
    dlas = []
    for jj in range(nDLA[0]):
        # Random z
        zabs = float(fzdla(rstate.random_sample()))
        # Random NHI
        NHI = float(fNHI(rstate.random_sample()))
        # Generate and append
        dla = dict(z=zabs, N=NHI,dlaid=jj)
        dlas.append(dla)

    # Generate model of DLAs
    dla_model = dla_spec(wave, dlas)

    # Return
    return dlas, dla_model
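# --- Usage sketch (illustrative, not part of the original source) ---
# Draw DLAs for a quasar at z=2.5 on a log-lambda grid; init_fNHI, calc_lz and
# dla_spec are assumed to be provided by this module, and numpy by np.
import numpy as np
wave = 10. ** np.arange(np.log10(3600.), np.log10(9800.), 1.e-4)   # Angstrom
dlas, dla_model = insert_dlas(wave, zem=2.5, seed=42)
for dla in dlas:
    print("z={:.3f}  log10(NHI)={:.2f}".format(dla['z'], dla['N']))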
Example #26
def parse(options=None):
    parser = argparse.ArgumentParser(description="Fit of standard star spectra in frames.")
    parser.add_argument('--frames', type = str, default = None, required=True, nargs='*',
                        help = 'list of path to DESI frame fits files (needs to be same exposure, spectro)')
    parser.add_argument('--skymodels', type = str, default = None, required=True, nargs='*',
                        help = 'list of path to DESI sky model fits files (needs to be same exposure, spectro)')
    parser.add_argument('--fiberflats', type = str, default = None, required=True, nargs='*',
                        help = 'list of path to DESI fiberflats fits files (needs to be same exposure, spectro)')
    parser.add_argument('--starmodels', type = str, help = 'path of spectro-photometric stellar spectra fits')
    parser.add_argument('-o','--outfile', type = str, help = 'output file for normalized stdstar model flux')
    parser.add_argument('--ncpu', type = int, default = default_nproc, required = False, help = 'use ncpu for multiprocessing')
    parser.add_argument('--delta-color', type = float, default = 0.2, required = False, help = 'max delta-color for the selection of standard stars (on top of meas. errors)')
    parser.add_argument('--color', type = str, default = "G-R", choices=['G-R', 'R-Z'], required = False, help = 'color for selection of standard stars')
    parser.add_argument('--z-max', type = float, default = 0.008, required = False, help = 'max peculiar velocity (blue/red)shift range')
    parser.add_argument('--z-res', type = float, default = 0.00002, required = False, help = 'dz grid resolution')
    parser.add_argument('--template-error', type = float, default = 0.1, required = False, help = 'fractional template error used in chi2 computation (about 0.1 for BOSS b1)')
    
    log = get_logger()
    args = None
    if options is None:
        args = parser.parse_args()
        cmd = ' '.join(sys.argv)
    else:
        args = parser.parse_args(options)
        cmd = 'desi_fit_stdstars ' + ' '.join(options)

    log.info('RUNNING {}'.format(cmd))

    return args
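# --- Usage sketch (illustrative, not part of the original source) ---
# All file names below are placeholders.
args = parse(['--frames', 'frame-b0-00000123.fits',
              '--skymodels', 'sky-b0-00000123.fits',
              '--fiberflats', 'fiberflat-b0-00000123.fits',
              '--starmodels', 'stdstar_templates.fits',
              '-o', 'stdstars-0-00000123.fits'])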
Example #27
    def __init__(self,wave,flux=None,ivar=None,mask=None,resolution=None):
        assert wave.ndim == 1, "Input wavelength should be 1D"
        assert (flux is None) or (flux.shape == wave.shape), "wave and flux should have same shape"
        assert (ivar is None) or (ivar.shape == wave.shape), "wave and ivar should have same shape"
        assert (mask is None) or (mask.shape == wave.shape), "wave and mask should have same shape"
        assert (resolution is None) or (isinstance(resolution, desispec.resolution.Resolution))
        assert (resolution is None) or (resolution.shape[0] == len(wave)), "resolution size mismatch to wave"

        self.wave = wave
        self.flux = flux
        self.ivar = ivar
        self.mask = mask
        # if mask is None:
        #     self.mask = np.zeros(self.flux.shape, dtype=np.uint32)
        # else:
        #     self.mask = util.mask32(mask)
        self.resolution = resolution
        self.R = resolution #- shorthand
        self.log = get_logger()
        # Initialize the quantities we will accumulate during co-addition. Note that our
        # internal Cinv is a dense matrix.
        if ivar is None:
            n = len(wave)
            self.Cinv = np.zeros((n,n))
            self.Cinv_f = np.zeros((n,))
        else:
            assert flux is not None and resolution is not None,'Missing flux and/or resolution.'
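            # Accumulate Cinv = R^T diag(ivar) R and Cinv_f = R^T (ivar * flux):
            # the inverse covariance and weighted flux of the deconvolved spectrum.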
            diag_ivar = scipy.sparse.dia_matrix((ivar[np.newaxis,:],[0]),resolution.shape)
            self.Cinv = self.resolution.T.dot(diag_ivar.dot(self.resolution))
            self.Cinv_f = self.resolution.T.dot(self.ivar*self.flux)
Example #28
def check_env():
    """
    Check required environment variables; raise RuntimeError if any are missing
    """
    log = logging.get_logger()
    #- template locations
    missing_env = False
    if 'DESI_BASIS_TEMPLATES' not in os.environ:
        log.warning('missing $DESI_BASIS_TEMPLATES needed for simulating spectra')
        missing_env = True

    elif not os.path.isdir(os.getenv('DESI_BASIS_TEMPLATES')):
        log.warning('missing $DESI_BASIS_TEMPLATES directory')
        log.warning('e.g. see NERSC:/project/projectdirs/desi/spectro/templates/basis_templates/v1.0')
        missing_env = True

    for name in (
        'DESI_SPECTRO_SIM', 'DESI_SPECTRO_REDUX', 'PIXPROD', 'SPECPROD', 'DESIMODEL'):
        if name not in os.environ:
            log.warning("missing ${0}".format(name))
            missing_env = True

    if missing_env:
        log.warning("Why are these needed?")
        log.warning("    Simulations written to $DESI_SPECTRO_SIM/$PIXPROD/")
        log.warning("    Raw data read from $DESI_SPECTRO_DATA/")
        log.warning("    Spectro pipeline output written to $DESI_SPECTRO_REDUX/$SPECPROD/")
        log.warning("    Templates are read from $DESI_BASIS_TEMPLATES")

    #- Wait until end to raise exception so that we report everything that
    #- is missing before actually failing
    if missing_env:
        log.critical("missing env vars; exiting without running pipeline")
        raise RuntimeError("missing env vars; exiting without running pipeline")
Example #29
def merge_psf(inputs, output):

    log = get_logger()

    npsf = len(inputs)
    log.info("Will merge {} PSFs in {}".format(npsf,output))

    # we will add/change data to the first PSF
    psf_hdulist=fits.open(inputs[0])
    for input_filename in inputs[1:] :
        log.info("merging {} into {}".format(input_filename,inputs[0]))
        other_psf_hdulist=fits.open(input_filename)

    # look at which fibers were actually fit
        i=np.where(other_psf_hdulist["PSF"].data["PARAM"]=="STATUS")[0][0]
        status_of_fibers = \
            other_psf_hdulist["PSF"].data["COEFF"][i][:,0].astype(int)
        selected_fibers = np.where(status_of_fibers==0)[0]
        log.info("fitted fibers in PSF {} = {}".format(input_filename,
            selected_fibers))
        if selected_fibers.size == 0 :
            log.warning("no fiber with status=0 found in {}".format(
                input_filename))
            other_psf_hdulist.close()
            continue

        # copy xtrace and ytrace
        psf_hdulist["XTRACE"].data[selected_fibers] = \
            other_psf_hdulist["XTRACE"].data[selected_fibers]
        psf_hdulist["YTRACE"].data[selected_fibers] = \
            other_psf_hdulist["YTRACE"].data[selected_fibers]

        # copy parameters
        parameters = psf_hdulist["PSF"].data["PARAM"]
        for param in parameters :
            i0=np.where(psf_hdulist["PSF"].data["PARAM"]==param)[0][0]
            i1=np.where(other_psf_hdulist["PSF"].data["PARAM"]==param)[0][0]
            psf_hdulist["PSF"].data["COEFF"][i0][selected_fibers] = \
                other_psf_hdulist["PSF"].data["COEFF"][i1][selected_fibers]

        # copy bundle chi2
        i = np.where(other_psf_hdulist["PSF"].data["PARAM"]=="BUNDLE")[0][0]
        bundles = np.unique(other_psf_hdulist["PSF"].data["COEFF"][i]\
            [selected_fibers,0].astype(int))
        log.info("fitted bundles in PSF {} = {}".format(input_filename,
            bundles))
        for b in bundles :
            for key in [ "B{:02d}RCHI2".format(b), "B{:02d}NDATA".format(b),
                "B{:02d}NPAR".format(b) ]:
                psf_hdulist["PSF"].header[key] = \
                    other_psf_hdulist["PSF"].header[key]
        # close file
        other_psf_hdulist.close()

    # write
    psf_hdulist.writeto(output,overwrite=True)
    log.info("Wrote PSF {}".format(output))

    return
Example #30
def run_task(name, opts, comm=None, logfile=None, db=None):
    """Run a single task.

    Based on the name of the task, call the appropriate run function for that
    task.  Log output to the specified file.  Run using the specified MPI
    communicator and optionally update state to the specified database.

    Note:  This function DOES NOT check the database or filesystem to see if
    the task has been completed or if its dependencies exist.  It assumes that
    some higher-level code has done that if necessary.

    Args:
        name (str): the name of this task.
        opts (dict): options to use for this task.
        comm (mpi4py.MPI.Comm): optional MPI communicator.
        logfile (str): output log file.  If None, do not redirect output to a
            file.
        db (pipeline.db.DB): The optional database to update.

    Returns:
        int: the total number of processes that failed.

    """
    from .tasks.base import task_classes, task_type
    log = get_logger()

    ttype = task_type(name)

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    if rank == 0:
        if (logfile is not None) and os.path.isfile(logfile):
            os.remove(logfile)
        # Mark task as in progress
        if db is not None:
            task_classes[ttype].state_set(db=db,name=name,state="running")

    failcount = 0
    if logfile is None:
        # No redirection
        if db is None:
            failcount = task_classes[ttype].run(name, opts, comm=comm)
        else:
            failcount = task_classes[ttype].run_and_update(db, name, opts,
                comm=comm)
    else:
        with stdouterr_redirected(to=logfile, comm=comm):
            if db is None:
                failcount = task_classes[ttype].run(name, opts, comm=comm)
            else:
                failcount = task_classes[ttype].run_and_update(db, name, opts,
                    comm=comm)

    return failcount
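# --- Illustrative call pattern (not from the source; the task name is
# hypothetical and a real pipeline setup is required) ---
# failed = run_task('fibermap_20200102_00000123', opts={}, comm=None,
#                   logfile='fibermap.log', db=None)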
Example #31
def read_frame(filename, nspec=None, skip_resolution=False):
    """Reads a frame fits file and returns its data.

    Args:
        filename: path to a file, or (night, expid, camera) tuple where
            night = string YEARMMDD
            expid = integer exposure ID
            camera = b0, r1, .. z9
        nspec: int, optional
            number of spectra to read, starting from the first; default is all
        skip_resolution: bool, optional
            Speed up read time (>5x) by avoiding the Resolution matrix

    Returns:
        desispec.Frame object with attributes wave, flux, ivar, etc.
    """
    log = get_logger()

    #- check if filename is (night, expid, camera) tuple instead
    if not isinstance(filename, str):
        night, expid, camera = filename
        filename = findfile('frame', night, expid, camera)

    if not os.path.isfile(filename):
        raise IOError("cannot open " + filename)

    fx = fits.open(filename, uint=True, memmap=False)
    hdr = fx[0].header
    flux = native_endian(fx['FLUX'].data.astype('f8'))
    ivar = native_endian(fx['IVAR'].data.astype('f8'))
    wave = native_endian(fx['WAVELENGTH'].data.astype('f8'))
    if 'MASK' in fx:
        mask = native_endian(fx['MASK'].data)
    else:
        mask = None  #- let the Frame object create the default mask

    # Init
    resolution_data = None
    qwsigma = None
    qndiag = None
    fibermap = None
    chi2pix = None
    scores = None
    scores_comments = None

    if skip_resolution:
        pass
    elif 'RESOLUTION' in fx:
        resolution_data = native_endian(fx['RESOLUTION'].data.astype('f8'))
    elif 'QUICKRESOLUTION' in fx:
        qr = fx['QUICKRESOLUTION'].header
        qndiag = qr['NDIAG']
        qwsigma = native_endian(fx['QUICKRESOLUTION'].data.astype('f4'))

    if 'FIBERMAP' in fx:
        fibermap = Table(fx['FIBERMAP'].data)
        if 'DESIGN_X' in fibermap.colnames:
            fibermap.rename_column('DESIGN_X', 'FIBERASSIGN_X')
        if 'DESIGN_Y' in fibermap.colnames:
            fibermap.rename_column('DESIGN_Y', 'FIBERASSIGN_Y')
    else:
        fibermap = None

    if 'CHI2PIX' in fx:
        chi2pix = native_endian(fx['CHI2PIX'].data.astype('f8'))
    else:
        chi2pix = None

    if 'SCORES' in fx:
        scores = fx['SCORES'].data
        # I need to open the header to read the comments
        scores_comments = dict()
        head = fx['SCORES'].header
        for i in range(1, len(scores.columns) + 1):
            k = 'TTYPE' + str(i)
            scores_comments[head[k]] = head.comments[k]
    else:
        scores = None
        scores_comments = None

    fx.close()

    if nspec is not None:
        flux = flux[0:nspec]
        ivar = ivar[0:nspec]
        if resolution_data is not None:
            resolution_data = resolution_data[0:nspec]
        elif qwsigma is not None:
            qwsigma = qwsigma[0:nspec]
        if chi2pix is not None:
            chi2pix = chi2pix[0:nspec]
        if mask is not None:
            mask = mask[0:nspec]

    # return flux,ivar,wave,resolution_data, hdr
    frame = Frame(wave,
                  flux,
                  ivar,
                  mask,
                  resolution_data,
                  meta=hdr,
                  fibermap=fibermap,
                  chi2pix=chi2pix,
                  scores=scores,
                  scores_comments=scores_comments,
                  wsigma=qwsigma,
                  ndiag=qndiag,
                  suppress_res_warning=skip_resolution)

    # Vette
    diagnosis = frame.vet()
    if diagnosis != 0:
        warnings.warn(
            "Frame did not pass simple vetting test. diagnosis={:d}".format(
                diagnosis))
        log.error(
            "Frame did not pass simple vetting test. diagnosis={:d}".format(
                diagnosis))
    # Return
    return frame
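# --- Usage sketch (illustrative, not part of the original source) ---
# Read by explicit path or by a (night, expid, camera) tuple; the names below
# are placeholders.
frame = read_frame('frame-b0-00000123.fits')
# frame = read_frame(('20200102', 123, 'b0'), nspec=50, skip_resolution=True)
print(frame.wave.shape, frame.flux.shape)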
Example #32
def mean_psf(inputs, output):

    log = get_logger()

    npsf = len(inputs)
    log.info("Will compute the average of {} PSFs".format(npsf))

    refhead = None
    tables = []
    xtrace = []
    ytrace = []
    wavemins = []
    wavemaxs = []

    hdulist = None
    bundle_rchi2 = []
    nbundles = None
    nfibers_per_bundle = None

    for input in inputs:
        log.info("Adding {}".format(input))
        if not os.path.isfile(input):
            log.warning("missing {}".format(input))
            continue
        psf = fits.open(input)
        if refhead is None:
            hdulist = psf
            refhead = psf["PSF"].header
            nfibers = \
                (psf["PSF"].header["FIBERMAX"]-psf["PSF"].header["FIBERMIN"])+1
            PSFVER = int(refhead["PSFVER"])
            if (PSFVER < 3):
                log.error("need PSFVER>=3, got {}".format(PSFVER))
                sys.exit(1)

        else:
            if not compatible(psf["PSF"].header, refhead):
                log.error("psfs {} and {} are not compatible".format(
                    inputs[0], input))
                sys.exit(12)
        tables.append(psf["PSF"].data)
        wavemins.append(psf["PSF"].header["WAVEMIN"])
        wavemaxs.append(psf["PSF"].header["WAVEMAX"])

        if "XTRACE" in psf:
            xtrace.append(psf["XTRACE"].data)
        if "YTRACE" in psf:
            ytrace.append(psf["YTRACE"].data)

        rchi2 = []
        b = 0
        while "B{:02d}RCHI2".format(b) in psf["PSF"].header:
            rchi2.append(psf["PSF"].header["B{:02d}RCHI2".format(b)])
            b += 1
        rchi2 = np.array(rchi2)
        nbundles = rchi2.size
        bundle_rchi2.append(rchi2)

    npsf = len(tables)
    bundle_rchi2 = np.array(bundle_rchi2)
    log.debug("bundle_rchi2= {}".format(str(bundle_rchi2)))
    median_bundle_rchi2 = np.median(bundle_rchi2)
    rchi2_threshold = median_bundle_rchi2 + 1.
    log.debug("median chi2={} threshold={}".format(median_bundle_rchi2,
                                                   rchi2_threshold))

    WAVEMIN = refhead["WAVEMIN"]
    WAVEMAX = refhead["WAVEMAX"]
    FIBERMIN = int(refhead["FIBERMIN"])
    FIBERMAX = int(refhead["FIBERMAX"])

    fibers_in_bundle = {}
    i = np.where(tables[0]["PARAM"] == "BUNDLE")[0][0]
    bundle_of_fibers = tables[0]["COEFF"][i][:, 0].astype(int)
    bundles = np.unique(bundle_of_fibers)
    for b in bundles:
        fibers_in_bundle[b] = np.where(bundle_of_fibers == b)[0]

    for entry in range(tables[0].size):
        PARAM = tables[0][entry]["PARAM"]
        log.info("Averaging '{}' coefficients".format(PARAM))
        coeff = [tables[0][entry]["COEFF"]]
        npar = coeff[0][1].size
        for p in range(1, npsf):

            if wavemins[p] == WAVEMIN and wavemaxs[p] == WAVEMAX:
                coeff.append(tables[p][entry]["COEFF"])
            else:
                log.info("need to refit legendre polynomial ...")
                icoeff = tables[p][entry]["COEFF"]
                ocoeff = np.zeros(icoeff.shape)
                # need to reshape legpol
                iu = np.linspace(-1, 1, npar + 3)
                iwavemin = wavemins[p]
                iwavemax = wavemaxs[p]
                wave = (iu + 1.) / 2. * (iwavemax - iwavemin) + iwavemin
                ou = (wave - WAVEMIN) / (WAVEMAX - WAVEMIN) * 2. - 1.
                for f in range(icoeff.shape[0]):
                    val = legval(iu, icoeff[f])
                    ocoeff[f] = legfit(ou, val, deg=npar - 1)
                coeff.append(ocoeff)

        coeff = np.array(coeff)

        output_rchi2 = np.zeros((bundle_rchi2.shape[1]))
        output_coeff = np.zeros(tables[0][entry]["COEFF"].shape)

        # now merge, using rchi2 as selection score

        for bundle in fibers_in_bundle.keys():

            ok = np.where(bundle_rchi2[:, bundle] < rchi2_threshold)[0]
            #ok=np.array([0,1]) # debug

            if entry == 0:
                log.info("for fiber bundle {}, {} valid PSFs".format(
                    bundle, ok.size))

            if ok.size >= 2:  # use median
                log.debug("bundle #{} : use median".format(bundle))
                for f in fibers_in_bundle[bundle]:
                    output_coeff[f] = np.median(coeff[ok, f], axis=0)
                output_rchi2[bundle] = np.median(bundle_rchi2[ok, bundle])
            elif ok.size == 1:  # copy
                log.debug("bundle #{} : use only one psf ".format(bundle))
                for f in fibers_in_bundle[bundle]:
                    output_coeff[f] = coeff[ok[0], f]
                output_rchi2[bundle] = bundle_rchi2[ok[0], bundle]

            else:  # we have a problem here, take the smallest rchi2
                log.debug("bundle #{} : take smallest chi2 ".format(bundle))
                i = np.argmin(bundle_rchi2[:, bundle])
                for f in fibers_in_bundle[bundle]:
                    output_coeff[f] = coeff[i, f]
                output_rchi2[bundle] = bundle_rchi2[i, bundle]

        # now copy this in output table
        hdulist["PSF"].data["COEFF"][entry] = output_coeff
        # change bundle chi2
        for bundle in range(output_rchi2.size):
            hdulist["PSF"].header["B{:02d}RCHI2".format(bundle)] = \
                output_rchi2[bundle]

        if len(xtrace) > 0:
            xtrace = np.array(xtrace)
            ytrace = np.array(ytrace)
            for p in range(xtrace.shape[0]):
                if wavemins[p] == WAVEMIN and wavemaxs[p] == WAVEMAX:
                    continue

                # need to reshape legpol
                iu = np.linspace(-1, 1, npar + 3)
                iwavemin = wavemins[p]
                iwavemax = wavemaxs[p]
                wave = (iu + 1.) / 2. * (iwavemax - iwavemin) + iwavemin
                ou = (wave - WAVEMIN) / (WAVEMAX - WAVEMIN) * 2. - 1.

                # refit each fiber's trace of PSF p onto the reference range
                for f in range(xtrace[p].shape[0]):
                    val = legval(iu, xtrace[p][f])
                    xtrace[p][f] = legfit(ou, val, deg=xtrace[p].shape[1] - 1)
                    val = legval(iu, ytrace[p][f])
                    ytrace[p][f] = legfit(ou, val, deg=ytrace[p].shape[1] - 1)

            hdulist["xtrace"].data = np.median(np.array(xtrace), axis=0)
            hdulist["ytrace"].data = np.median(np.array(ytrace), axis=0)

        # alter other keys in header
        hdulist["PSF"].header[
            "EXPID"] = 0.  # it's a mix, need to add the expids

    for hdu in ["XTRACE", "YTRACE", "PSF"]:
        if hdu in hdulist:
            for input in inputs:
                hdulist[hdu].header["comment"] = "inc {}".format(input)

    # save output PSF
    hdulist.writeto(output, overwrite=True)
    log.info("wrote {}".format(output))

    return
Example #33
def merge_psf(inputs, output):

    log = get_logger()

    npsf = len(inputs)
    log.info("Will merge {} PSFs in {}".format(npsf, output))

    # we will add/change data to the first PSF
    psf_hdulist = fits.open(inputs[0])
    for input_filename in inputs[1:]:
        log.info("merging {} into {}".format(input_filename, inputs[0]))
        other_psf_hdulist = fits.open(input_filename)

        # look at which fibers were actually fit
        i = np.where(other_psf_hdulist["PSF"].data["PARAM"] == "STATUS")[0][0]
        status_of_fibers = \
            other_psf_hdulist["PSF"].data["COEFF"][i][:,0].astype(int)
        selected_fibers = np.where(status_of_fibers == 0)[0]
        log.info("fitted fibers in PSF {} = {}".format(input_filename,
                                                       selected_fibers))
        if selected_fibers.size == 0:
            log.warning(
                "no fiber with status=0 found in {}".format(input_filename))
            other_psf_hdulist.close()
            continue

        # copy xtrace and ytrace
        psf_hdulist["XTRACE"].data[selected_fibers] = \
            other_psf_hdulist["XTRACE"].data[selected_fibers]
        psf_hdulist["YTRACE"].data[selected_fibers] = \
            other_psf_hdulist["YTRACE"].data[selected_fibers]

        # copy parameters
        parameters = psf_hdulist["PSF"].data["PARAM"]
        for param in parameters:
            i0 = np.where(psf_hdulist["PSF"].data["PARAM"] == param)[0][0]
            i1 = np.where(
                other_psf_hdulist["PSF"].data["PARAM"] == param)[0][0]
            psf_hdulist["PSF"].data["COEFF"][i0][selected_fibers] = \
                other_psf_hdulist["PSF"].data["COEFF"][i1][selected_fibers]

        # copy bundle chi2
        i = np.where(other_psf_hdulist["PSF"].data["PARAM"] == "BUNDLE")[0][0]
        bundles = np.unique(other_psf_hdulist["PSF"].data["COEFF"][i]\
            [selected_fibers,0].astype(int))
        log.info("fitted bundles in PSF {} = {}".format(
            input_filename, bundles))
        for b in bundles:
            for key in [
                    "B{:02d}RCHI2".format(b), "B{:02d}NDATA".format(b),
                    "B{:02d}NPAR".format(b)
            ]:
                psf_hdulist["PSF"].header[key] = \
                    other_psf_hdulist["PSF"].header[key]
        # close file
        other_psf_hdulist.close()

    # write
    psf_hdulist.writeto(output, overwrite=True)
    log.info("Wrote PSF {}".format(output))

    return
Example #34
def main(args, comm=None):

    log = get_logger()

    imgfile = args.input_image
    outfile = args.output_psf

    if args.input_psf is not None:
        inpsffile = args.input_psf
    else:
        from desispec.calibfinder import findcalibfile
        hdr = fits.getheader(imgfile)
        inpsffile = findcalibfile([
            hdr,
        ], 'PSF')

    optarray = []
    if args.extra is not None:
        optarray = args.extra.split()

    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)

    specmax = specmin + nspec

    # Now we divide our spectra into bundles

    checkbundles = set()
    checkbundles.update(
        np.floor_divide(np.arange(specmin, specmax),
                        bundlesize * np.ones(nspec)).astype(int))
    bundles = sorted(checkbundles)
    nbundle = len(bundles)

    bspecmin = {}
    bnspec = {}
    for b in bundles:
        if specmin > b * bundlesize:
            bspecmin[b] = specmin
        else:
            bspecmin[b] = b * bundlesize
        if (b + 1) * bundlesize > specmax:
            bnspec[b] = specmax - bspecmin[b]
        else:
            bnspec[b] = (b + 1) * bundlesize - bspecmin[b]

    # Now we assign bundles to processes

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    mynbundle = int(nbundle / nproc)
    myfirstbundle = 0
    leftover = nbundle % nproc
    if rank < leftover:
        mynbundle += 1
        myfirstbundle = rank * mynbundle
    else:
        myfirstbundle = ((mynbundle + 1) * leftover) + \
            (mynbundle * (rank - leftover))
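    # The first `leftover` ranks take one extra bundle each, so the nbundle
    # bundles are split across the nproc ranks in contiguous, nearly equal blocks.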

    if rank == 0:
        # Print parameters
        log.info("specex:  using {} processes".format(nproc))
        log.info("specex:  input image = {}".format(imgfile))
        log.info("specex:  input PSF = {}".format(inpsffile))
        log.info("specex:  output = {}".format(outfile))
        log.info("specex:  bundlesize = {}".format(bundlesize))
        log.info("specex:  specmin = {}".format(specmin))
        log.info("specex:  specmax = {}".format(specmax))
        if args.broken_fibers:
            log.info("specex:  broken fibers = {}".format(args.broken_fibers))

    # get the root output file

    outpat = re.compile(r'(.*)\.fits')
    outmat = outpat.match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)

    outdir = os.path.dirname(outroot)
    if rank == 0:
        if not os.path.isdir(outdir):
            os.makedirs(outdir)

    failcount = 0

    for b in range(myfirstbundle, myfirstbundle + mynbundle):
        outbundle = "{}_{:02d}".format(outroot, b)
        outbundlefits = "{}.fits".format(outbundle)
        com = ['desi_psf_fit']
        com.extend(['-a', imgfile])
        com.extend(['--in-psf', inpsffile])
        com.extend(['--out-psf', outbundlefits])
        com.extend(['--first-bundle', "{}".format(b)])
        com.extend(['--last-bundle', "{}".format(b)])
        com.extend(['--first-fiber', "{}".format(bspecmin[b])])
        com.extend(['--last-fiber', "{}".format(bspecmin[b] + bnspec[b] - 1)])
        if args.broken_fibers:
            com.extend(['--broken-fibers', "{}".format(args.broken_fibers)])
        if args.debug:
            com.extend(['--debug'])

        com.extend(optarray)

        log.debug("proc {} calling {}".format(rank, " ".join(com)))

        argc = len(com)
        arg_buffers = [ct.create_string_buffer(com[i].encode('ascii')) \
            for i in range(argc)]
        addrlist = [ ct.cast(x, ct.POINTER(ct.c_char)) for x in \
            map(ct.addressof, arg_buffers) ]
        arg_pointers = (ct.POINTER(ct.c_char) * argc)(*addrlist)

        retval = libspecex.cspecex_desi_psf_fit(argc, arg_pointers)

        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                      "value {} running {}".format(rank, retval, comstr))
            failcount += 1

    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")

    if rank == 0:
        outfits = "{}.fits".format(outroot)

        inputs = ["{}_{:02d}.fits".format(outroot, x) for x in bundles]

        if args.disable_merge:
            log.info("don't merge")
        else:
            #- Empirically it appears that files written by one rank sometimes
            #- aren't fully buffer-flushed and closed before getting here,
            #- despite the MPI allreduce barrier.  Pause to let I/O catch up.
            log.info('HACK: taking a 20 sec pause before merging')
            sys.stdout.flush()
            time.sleep(20.)

            merge_psf(inputs, outfits)

            log.info('done merging')

            if failcount == 0:
                # only remove the per-bundle files if the merge was good
                for f in inputs:
                    if os.path.isfile(f):
                        os.remove(f)

    if comm is not None:
        failcount = comm.bcast(failcount, root=0)

    if failcount > 0:
        # all processes throw
        raise RuntimeError("merging of per-bundle files failed")

    return
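For reference, a self-contained sketch (not from the original module) of the block assignment used above: the first nbundle % nproc ranks get one extra bundle, and each rank's bundles are contiguous.

def assign_bundles(nbundle, nproc):
    """Mirror the rank -> bundle mapping computed in main() above."""
    assignment = {}
    for rank in range(nproc):
        mynbundle = nbundle // nproc
        leftover = nbundle % nproc
        if rank < leftover:
            mynbundle += 1
            myfirst = rank * mynbundle
        else:
            myfirst = (mynbundle + 1) * leftover + mynbundle * (rank - leftover)
        assignment[rank] = list(range(myfirst, myfirst + mynbundle))
    return assignment

assert assign_bundles(5, 2) == {0: [0, 1, 2], 1: [3, 4]}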
Example #35
def parse(options=None):
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--simspec',
                        type=str,
                        required=True,
                        help="input simspec file")
    parser.add_argument('--fibermap', type=str, help='input fibermap file')
    parser.add_argument(
        '-n',
        '--nspec',
        type=int,
        default=100,
        help='number of spectra to be simulated, starting from first')
    parser.add_argument('--nstart',
                        type=int,
                        default=0,
                        help='starting spectrum number for simulation (0-4999)')
    parser.add_argument('--spectrograph',
                        type=int,
                        default=None,
                        help='Spectrograph no. 0-9')
    parser.add_argument('--config',
                        type=str,
                        default='desi',
                        help='specsim configuration')
    parser.add_argument('-b',
                        '--n_fibers',
                        type=int,
                        default=650,
                        help='total number of fibers')
    parser.add_argument('-t',
                        '--telescope',
                        type=str,
                        default='1m',
                        help='telescope',
                        choices=['1m', '160mm'])
    parser.add_argument('-l',
                        '--location',
                        type=str,
                        default='APO',
                        help='site location',
                        choices=['APO', 'LCO'])
    parser.add_argument('-s',
                        '--seed',
                        type=int,
                        default=0,
                        help="random seed")
    # Only produce uncalibrated output
    parser.add_argument('--frameonly',
                        action="store_true",
                        help="only output frame files")

    # Moon options if bright or gray time
    parser.add_argument('--moon-phase',
                        type=float,
                        help='moon phase (0=full, 1=new)',
                        default=None,
                        metavar='')
    parser.add_argument('--moon-angle',
                        type=float,
                        help='separation angle to the moon (0-180 deg)',
                        default=None,
                        metavar='')
    parser.add_argument('--moon-zenith',
                        type=float,
                        help='zenith angle of the moon (0-90 deg)',
                        default=None,
                        metavar='')

    parser.add_argument(
        '--objtype',
        type=str,
        help='ELG, LRG, QSO, BGS, MWS, WD, DARK_MIX, or BRIGHT_MIX',
        default='DARK_MIX',
        metavar='')
    parser.add_argument('-a',
                        '--airmass',
                        type=float,
                        help='airmass',
                        default=None,
                        metavar='')
    parser.add_argument('-e',
                        '--exptime',
                        type=float,
                        help='exposure time (s)',
                        default=None,
                        metavar='')
    parser.add_argument('-o',
                        '--outdir',
                        type=str,
                        help='output directory',
                        default='.',
                        metavar='')
    parser.add_argument('-v',
                        '--verbose',
                        action='store_true',
                        help='toggle on verbose output')
    parser.add_argument(
        '--outdir-truth',
        type=str,
        help='optional alternative output directory for truth files',
        metavar='')

    # Object type specific options
    parser.add_argument('--zrange-qso',
                        type=float,
                        default=(0.5, 4.0),
                        nargs=2,
                        metavar='',
                        help='minimum and maximum redshift range for QSO')
    parser.add_argument('--zrange-elg',
                        type=float,
                        default=(0.6, 1.6),
                        nargs=2,
                        metavar='',
                        help='minimum and maximum redshift range for ELG')
    parser.add_argument('--zrange-lrg',
                        type=float,
                        default=(0.5, 1.1),
                        nargs=2,
                        metavar='',
                        help='minimum and maximum redshift range for LRG')
    parser.add_argument('--zrange-bgs',
                        type=float,
                        default=(0.01, 0.4),
                        nargs=2,
                        metavar='',
                        help='minimum and maximum redshift range for BGS')
    parser.add_argument(
        '--rmagrange-bgs',
        type=float,
        default=(15.0, 19.5),
        nargs=2,
        metavar='',
        help='Minimum and maximum BGS r-band (AB) magnitude range')
    parser.add_argument(
        '--sne-rfluxratiorange',
        type=float,
        default=(0.1, 1.0),
        nargs=2,
        metavar='',
        help='r-band flux ratio of the SNeIa spectrum relative to the galaxy')
    parser.add_argument('--add-SNeIa',
                        action='store_true',
                        help='include SNeIa spectra')

    if options is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(options)

    log = get_logger()
    if args.simspec:
        args.objtype = None
        if args.fibermap is None:
            dirname = os.path.dirname(os.path.abspath(args.simspec))
            filename = os.path.basename(args.simspec).replace(
                'simspec', 'fibermap')
            args.fibermap = os.path.join(dirname, filename)
            log.warning(
                'deriving fibermap {} from simspec input filename'.format(
                    args.fibermap))

    return args
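Usage sketch (not in the original source; the simspec path is hypothetical). When --fibermap is omitted, parse() derives it from the simspec file name:

args = parse(['--simspec', '/data/20200101/simspec-00000002.fits', '--nspec', '25'])
# args.fibermap == '/data/20200101/fibermap-00000002.fits' (derived, with a warning)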
Example #36
def compute_fiber_bundle_trace_shifts_using_psf(fibers,
                                                line,
                                                psf,
                                                image,
                                                maxshift=2.):
    """
    Computes trace shifts along x and y from a preprocessed image, a PSF (with trace coords), and a given emission line,
    by doing a forward model of the image.

    Args:
        fibers : 1D array with list of fibers
        line : float, wavelength of an emission line (in Angstrom)
        psf  : specter psf object
        image : DESI preprocessed image object

    Optional:
        maxshift : float maximum shift in pixels for 2D chi2 scan

    Returns:
        x  : 1D array of x coordinates on CCD (axis=1 in numpy image array, AXIS=0 in FITS, cross-dispersion axis = fiber number direction)
        y  : 1D array of y coordinates on CCD (axis=0 in numpy image array, AXIS=1 in FITS, wavelength dispersion axis)
        dx : 1D array of shifts along x coordinates on CCD
        dy : 1D array of shifts along y coordinates on CCD
        sx : 1D array of uncertainties on dx
        sy : 1D array of uncertainties on dy
    """
    log = get_logger()
    #log.info("compute_fiber_bundle_offsets fibers={} line={}".format(fibers,line))

    # get central coordinates of bundle for interpolation of offsets on CCD
    x, y = psf.xy([
        int(np.median(fibers)),
    ], line)

    try:
        nfibers = len(fibers)

        # compute stamp coordinates
        xstart = None
        xstop = None
        ystart = None
        ystop = None
        xs = []
        ys = []
        pix = []
        xx = []
        yy = []

        for fiber in fibers:
            txs, tys, tpix = psf.xypix(fiber, line)
            xs.append(txs)
            ys.append(tys)
            pix.append(tpix)
            if xstart is None:
                xstart = txs.start
                xstop = txs.stop
                ystart = tys.start
                ystop = tys.stop
            else:
                xstart = min(xstart, txs.start)
                xstop = max(xstop, txs.stop)
                ystart = min(ystart, tys.start)
                ystop = max(ystop, tys.stop)

        # load stamp data, with margins to avoid problems with shifted psf
        margin = int(maxshift) + 1
        stamp = np.zeros(
            (ystop - ystart + 2 * margin, xstop - xstart + 2 * margin))
        stampivar = np.zeros(stamp.shape)
        stamp[margin:-margin, margin:-margin] = image.pix[ystart:ystop,
                                                          xstart:xstop]
        stampivar[margin:-margin, margin:-margin] = image.ivar[ystart:ystop,
                                                               xstart:xstop]

        # use a fixed footprint despite changes of psf stamps
        # so that the chi2 is always based on the same data set
        footprint = np.zeros(stamp.shape)
        for i in range(nfibers):
            footprint[margin - ystart + ys[i].start:margin - ystart +
                      ys[i].stop, margin - xstart + xs[i].start:margin -
                      xstart + xs[i].stop] = 1

        #plt.imshow(footprint) ; plt.show() ; sys.exit(12)

        # define grid of shifts to test
        res = 0.5
        nshift = int(maxshift / res)
        dx = res * np.tile(
            np.arange(2 * nshift + 1) - nshift, (2 * nshift + 1, 1))
        dy = dx.T
        original_shape = dx.shape
        dx = dx.ravel()
        dy = dy.ravel()
        chi2 = np.zeros(dx.shape)

        A = np.zeros((nfibers, nfibers))
        B = np.zeros((nfibers))
        mods = np.zeros((nfibers,) + stamp.shape)

        debugging = False

        if debugging:  # FOR DEBUGGING KEEP MODELS
            models = []

        # loop on possible shifts
        # refit fluxes and compute chi2
        for d in range(len(dx)):
            # print(d,dx[d],dy[d])
            A *= 0
            B *= 0
            mods *= 0

            for i, fiber in enumerate(fibers):

                # apply the PSF shift
                psf._cache = {}  # reset cache !!
                psf.coeff['X']._coeff[fiber][0] += dx[d]
                psf.coeff['Y']._coeff[fiber][0] += dy[d]

                # compute pix and paste on stamp frame
                xx, yy, pix = psf.xypix(fiber, line)
                mods[i][margin - ystart + yy.start:margin - ystart + yy.stop,
                        margin - xstart + xx.start:margin - xstart +
                        xx.stop] = pix

                # undo the PSF shift
                psf.coeff['X']._coeff[fiber][0] -= dx[d]
                psf.coeff['Y']._coeff[fiber][0] -= dy[d]

                B[i] = np.sum(stampivar * stamp * mods[i])
                for j in range(i + 1):
                    A[i, j] = np.sum(stampivar * mods[i] * mods[j])
                    if j != i:
                        A[j, i] = A[i, j]
            Ai = np.linalg.inv(A)
            flux = Ai.dot(B)
            model = np.zeros(stamp.shape)
            for i in range(nfibers):
                model += flux[i] * mods[i]
            chi2[d] = np.sum(stampivar * (stamp - model)**2)
            if debugging:
                models.append(model)

        if debugging:
            schi2 = chi2.reshape(original_shape).copy()  # FOR DEBUGGING
            sdx = dx.copy()
            sdy = dy.copy()

        # find minimum chi2 grid point
        k = chi2.argmin()
        j, i = np.unravel_index(k, ((2 * nshift + 1), (2 * nshift + 1)))
        #print("node dx,dy=",dx.reshape(original_shape)[j,i],dy.reshape(original_shape)[j,i])

        # cut a region around minimum
        delta = 1
        istart = max(0, i - delta)
        istop = min(2 * nshift + 1, i + delta + 1)
        jstart = max(0, j - delta)
        jstop = min(2 * nshift + 1, j + delta + 1)
        chi2 = chi2.reshape(original_shape)[jstart:jstop, istart:istop].ravel()
        dx = dx.reshape(original_shape)[jstart:jstop, istart:istop].ravel()
        dy = dy.reshape(original_shape)[jstart:jstop, istart:istop].ravel()
        # fit a 2D polynomial of degree 2 around the minimum
        m = np.array([dx * 0 + 1, dx, dy, dx**2, dy**2, dx * dy]).T
        c, r, rank, s = np.linalg.lstsq(m, chi2, rcond=None)
        if c[3] > 0 and c[4] > 0:
            # get minimum
            # dchi2/dx=0 : c[1]+2*c[3]*dx+c[5]*dy = 0
            # dchi2/dy=0 : c[2]+2*c[4]*dy+c[5]*dx = 0
            a = np.array([[2 * c[3], c[5]], [c[5], 2 * c[4]]])
            b = np.array([c[1], c[2]])
            t = -np.linalg.inv(a).dot(b)
            dx = t[0]
            dy = t[1]
            sx = 1. / np.sqrt(c[3])
            sy = 1. / np.sqrt(c[4])
            #print("interp dx,dy=",dx,dy)

            if debugging:  # FOR DEBUGGING
                import matplotlib.pyplot as plt
                plt.figure()
                plt.subplot(2, 2, 1, title="chi2")
                plt.imshow(schi2,
                           extent=(-nshift * res, nshift * res, -nshift * res,
                                   nshift * res),
                           origin="lower",
                           interpolation="nearest")
                plt.plot(dx, dy, "+", color="white", ms=20)
                plt.xlabel("x")
                plt.ylabel("y")
                plt.subplot(2, 2, 2, title="data")
                plt.imshow(stamp * footprint,
                           origin="lower",
                           interpolation="nearest")
                plt.grid()
                k0 = np.argmin(sdx**2 + sdy**2)
                plt.subplot(2, 2, 3, title="original psf")
                plt.imshow(models[k0], origin="lower", interpolation="nearest")
                plt.grid()
                plt.subplot(2, 2, 4, title="shifted psf")
                plt.imshow(models[k], origin="lower", interpolation="nearest")
                plt.grid()
                plt.show()

        else:
            log.warning(
                "fit failed (bad chi2 surf.) for fibers [%d:%d] line=%dA" %
                (fibers[0], fibers[-1] + 1, int(line)))
            dx = 0.
            dy = 0.
            sx = 10.
            sy = 10.
    except LinAlgError:
        log.warning(
            "fit failed (masked or missing data) for fibers [%d:%d] line=%dA" %
            (fibers[0], fibers[-1] + 1, int(line)))
        dx = 0.
        dy = 0.
        sx = 10.
        sy = 10.

    return x, y, dx, dy, sx, sy
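Standalone check (assuming the same paraboloid parameterization as above, chi2 = c0 + c1*x + c2*y + c3*x**2 + c4*y**2 + c5*x*y, with toy coefficients) that the 2x2 linear solve recovers the analytic minimum:

import numpy as np

c = np.array([1.0, -0.2, 0.3, 2.0, 1.5, 0.1])  # toy coefficients with c3, c4 > 0
a = np.array([[2 * c[3], c[5]], [c[5], 2 * c[4]]])
b = np.array([c[1], c[2]])
dx_min, dy_min = -np.linalg.inv(a).dot(b)
# the gradient of chi2 vanishes at (dx_min, dy_min)
assert np.allclose(a.dot([dx_min, dy_min]) + b, 0)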
Example #37
def quietDesiLogger(loglvl=20):
    """Initialize the DESI logger at the given level (default 20=INFO) to quiet DEBUG output."""
    from desiutil.log import get_logger
    get_logger(level=loglvl)
Example #38
def compute_uniform_sky(frame,
                        nsig_clipping=4.,
                        max_iterations=100,
                        model_ivar=False,
                        add_variance=True):
    """Compute a sky model.
    
    Sky[fiber,i] = R[fiber,i,j] Flux[j]
    
    Input fluxes are expected to be flatfielded!
    We don't check this in this routine.

    Args:
        frame : Frame object, which includes attributes
          - wave : 1D wavelength grid in Angstroms
          - flux : 2D flux[nspec, nwave] density
          - ivar : 2D inverse variance of flux
          - mask : 2D mask of flux (0=good)
          - resolution_data : 3D[nspec, ndiag, nwave]  (only sky fibers)
        nsig_clipping : [optional] sigma clipping value for outlier rejection

    Optional:
        max_iterations : int , number of iterations
        model_ivar : replace ivar by a model to avoid bias due to correlated flux and ivar. this has a negligible effect on sims.
        add_variance : evaluate calibration error and add this to the sky model variance
        
    Returns:
        SkyModel object with attributes wave, flux, ivar, mask
    """

    log = get_logger()
    log.info("starting")

    # Grab sky fibers on this frame
    skyfibers = np.where(frame.fibermap['OBJTYPE'] == 'SKY')[0]
    assert np.max(skyfibers) < 500  #- indices, not fiber numbers

    nwave = frame.nwave
    nfibers = len(skyfibers)

    current_ivar = frame.ivar[skyfibers].copy() * (frame.mask[skyfibers] == 0)
    flux = frame.flux[skyfibers]
    Rsky = frame.R[skyfibers]

    input_ivar = None
    if model_ivar:
        log.info(
            "use a model of the inverse variance to remove bias due to correlated ivar and flux"
        )
        input_ivar = current_ivar.copy()
        median_ivar_vs_wave = np.median(current_ivar, axis=0)
        median_ivar_vs_fiber = np.median(current_ivar, axis=1)
        median_median_ivar = np.median(median_ivar_vs_fiber)
        for f in range(current_ivar.shape[0]):
            threshold = 0.01
            current_ivar[f] = median_ivar_vs_fiber[
                f] / median_median_ivar * median_ivar_vs_wave
            # keep input ivar for very low weights
            ii = (input_ivar[f] <= (threshold * median_ivar_vs_wave))
            #log.info("fiber {} keep {}/{} original ivars".format(f,np.sum(ii),current_ivar.shape[1]))
            current_ivar[f][ii] = input_ivar[f][ii]

    sqrtw = np.sqrt(current_ivar)
    sqrtwflux = sqrtw * flux

    chi2 = np.zeros(flux.shape)

    nout_tot = 0
    for iteration in range(max_iterations):

        # the matrix A is 1/2 of the second derivative of the chi2 with respect to the parameters
        # A_ij = 1/2 d2(chi2)/di/dj
        # A_ij = sum_fiber sum_wave_w ivar[fiber,w] d(model)/di[fiber,w] * d(model)/dj[fiber,w]

        # the vector B is 1/2 of the first derivative of the chi2 with respect to the parameters
        # B_i  = 1/2 d(chi2)/di
        # B_i  = sum_fiber sum_wave_w ivar[fiber,w] d(model)/di[fiber,w] * (flux[fiber,w]-model[fiber,w])

        # the model is model[fiber]=R[fiber]*sky
        # and the parameters are the unconvolved sky flux at the wavelength i

        # so, d(model)/di[fiber,w] = R[fiber][w,i]
        # this gives
        # A_ij = sum_fiber  sum_wave_w ivar[fiber,w] R[fiber][w,i] R[fiber][w,j]
        # A = sum_fiber ( diag(sqrt(ivar))*R[fiber] ) ( diag(sqrt(ivar))* R[fiber] )^t
        # A = sum_fiber sqrtwR[fiber] sqrtwR[fiber]^t
        # and
        # B = sum_fiber sum_wave_w ivar[fiber,w] R[fiber][w] * flux[fiber,w]
        # B = sum_fiber sum_wave_w sqrt(ivar)[fiber,w]*flux[fiber,w] sqrtwR[fiber,wave]

        #A=scipy.sparse.lil_matrix((nwave,nwave)).tocsr()
        A = np.zeros((nwave, nwave))
        B = np.zeros((nwave))

        # diagonal sparse matrix with content = sqrt(ivar) for a given fiber
        SD = scipy.sparse.lil_matrix((nwave, nwave))

        # loop on fiber to handle resolution
        for fiber in range(nfibers):
            if fiber % 10 == 0:
                log.info("iter %d sky fiber %d/%d" %
                         (iteration, fiber, nfibers))
            R = Rsky[fiber]

            # diagonal sparse matrix with content = sqrt(ivar)
            SD.setdiag(sqrtw[fiber])

            sqrtwR = SD * R  # each row r of R is multiplied by sqrtw[r]
            A += (sqrtwR.T * sqrtwR).todense()
            B += sqrtwR.T * sqrtwflux[fiber]

        log.info("iter %d solving" % iteration)
        w = A.diagonal() > 0
        A_pos_def = A[w, :]
        A_pos_def = A_pos_def[:, w]
        parameters = B * 0
        try:
            parameters[w] = cholesky_solve(A_pos_def, B[w])
        except Exception:
            log.info("cholesky failed, trying svd in iteration {}".format(
                iteration))
            parameters[w] = np.linalg.lstsq(A_pos_def, B[w], rcond=None)[0]

        log.info("iter %d compute chi2" % iteration)

        for fiber in range(nfibers):
            # the parameters are directly the unconvolved sky flux
            # so we simply have to reconvolve it
            fiber_convolved_sky_flux = Rsky[fiber].dot(parameters)
            chi2[fiber] = current_ivar[fiber] * (flux[fiber] -
                                                 fiber_convolved_sky_flux)**2

        log.info("rejecting")

        nout_iter = 0
        if iteration < 1:
            # only remove worst outlier per wave
            # apply rejection iteratively, only one entry per wave among fibers
            # find waves with outlier (fastest way)
            nout_per_wave = np.sum(chi2 > nsig_clipping**2, axis=0)
            selection = np.where(nout_per_wave > 0)[0]
            for i in selection:
                worst_entry = np.argmax(chi2[:, i])
                current_ivar[worst_entry, i] = 0
                sqrtw[worst_entry, i] = 0
                sqrtwflux[worst_entry, i] = 0
                nout_iter += 1

        else:
            # remove all of them at once
            bad = (chi2 > nsig_clipping**2)
            current_ivar *= (bad == 0)
            sqrtw *= (bad == 0)
            sqrtwflux *= (bad == 0)
            nout_iter += np.sum(bad)

        nout_tot += nout_iter

        sum_chi2 = float(np.sum(chi2))
        ndf = int(np.sum(chi2 > 0) - nwave)
        chi2pdf = 0.
        if ndf > 0:
            chi2pdf = sum_chi2 / ndf
        log.info("iter #%d chi2=%f ndf=%d chi2pdf=%f nout=%d" %
                 (iteration, sum_chi2, ndf, chi2pdf, nout_iter))

        if nout_iter == 0:
            break

    log.info("nout tot=%d" % nout_tot)

    # we now have to compute the sky model for all fibers
    # and propagate the uncertainties

    # when modeling the ivar, there is no need to restore the original ivar to
    # compute the model errors: the sky inverse variances are very similar

    log.info("compute the parameter covariance")
    # we may have to use a different method to compute this
    # covariance

    try:
        parameter_covar = cholesky_invert(A)
        # the above is too slow
        # maybe invert per block, sandwich by R
    except np.linalg.LinAlgError:
        log.warning(
            "cholesky_solve_and_invert failed, switching to np.linalg.lstsq and np.linalg.pinv"
        )
        parameter_covar = np.linalg.pinv(A)

    log.info("compute mean resolution")
    # we make an approximation for the variance to save CPU time
    # we use the average resolution of all fibers in the frame:
    mean_res_data = np.mean(frame.resolution_data, axis=0)
    Rmean = Resolution(mean_res_data)

    log.info("compute convolved sky and ivar")

    # The parameters are directly the unconvolved sky
    # First convolve with average resolution :
    convolved_sky_covar = Rmean.dot(parameter_covar).dot(Rmean.T.todense())

    # and keep only the diagonal
    convolved_sky_var = np.diagonal(convolved_sky_covar)

    # inverse
    convolved_sky_ivar = (convolved_sky_var > 0) / (convolved_sky_var +
                                                    (convolved_sky_var == 0))

    # and simply assume it is the same for all spectra
    cskyivar = np.tile(convolved_sky_ivar,
                       frame.nspec).reshape(frame.nspec, nwave)

    # The sky model for each fiber (simple convolution with resolution of each fiber)
    cskyflux = np.zeros(frame.flux.shape)
    for i in range(frame.nspec):
        cskyflux[i] = frame.R[i].dot(parameters)

    # look at chi2 per wavelength and increase sky variance to reach chi2/ndf=1
    if skyfibers.size > 1 and add_variance:
        modified_cskyivar = _model_variance(frame, cskyflux, cskyivar,
                                            skyfibers)
    else:
        modified_cskyivar = cskyivar.copy()

    # need to do better here
    mask = (cskyivar == 0).astype(np.uint32)

    return SkyModel(
        frame.wave.copy(),
        cskyflux,
        modified_cskyivar,
        mask,
        nrej=nout_tot,
        stat_ivar=cskyivar)  # keep a record of the statistical ivar for QA
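Usage sketch (not in the original source; `frame` is assumed to be a desispec Frame whose fibermap flags sky fibers with OBJTYPE=='SKY'):

skymodel = compute_uniform_sky(frame, nsig_clipping=4., add_variance=True)
# naive illustration of subtraction; the pipeline has its own routine for this
frame.flux -= skymodel.flux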
Example #39
def select_targets(infiles, numproc=4, cmxdir=None, noqso=False):
    """Process input files in parallel to select commissioning (cmx) targets

    Parameters
    ----------
    infiles : :class:`list` or `str`
        List of input filenames (tractor/sweep files) OR one filename.
    numproc : :class:`int`, optional, defaults to 4
        The number of parallel processes to use.
    cmxdir : :class:`str`, optional, defaults to :envvar:`CMX_DIR`
        Directory to find commissioning files to which to match, such
        as the CALSPEC stars. If not specified, the cmx directory is
        taken to be the value of :envvar:`CMX_DIR`.
    noqso : :class:`boolean`, optional, defaults to ``False``
        If passed, do not run the quasar selection. All QSO bits will be
        set to zero. Intended use is to speed unit tests.

    Returns
    -------
    :class:`~numpy.ndarray`
        The subset of input targets which pass the cmx cuts, including an extra
        column for `CMX_TARGET`.

    Notes
    -----
        - if numproc==1, use serial code instead of parallel.
    """
    from desiutil.log import get_logger
    log = get_logger()

    # -Convert single file to list of files.
    if isinstance(infiles, str):
        infiles = [
            infiles,
        ]

    # -Sanity check that files exist before going further.
    for filename in infiles:
        if not os.path.exists(filename):
            raise ValueError("{} doesn't exist".format(filename))

    # ADM retrieve/check the cmxdir.
    cmxdir = _get_cmxdir(cmxdir)

    def _finalize_targets(objects, cmx_target, priority_shift):
        # -keep only objects with at least one cmx target bit set
        keep = (cmx_target != 0)
        objects = objects[keep]
        cmx_target = cmx_target[keep]
        priority_shift = priority_shift[keep]

        # -Add *_target mask columns
        # ADM note that only cmx_target is defined for commissioning
        # ADM so just pass that around
        targets = finalize(objects,
                           cmx_target,
                           cmx_target,
                           cmx_target,
                           survey='cmx')
        # ADM shift the priorities of targets with functional priorities.
        targets["PRIORITY_INIT"] += priority_shift

        return targets

    # -functions to run on every brick/sweep file
    def _select_targets_file(filename):
        '''Returns targets in filename that pass the cuts'''
        objects = io.read_tractor(filename)
        cmx_target, priority_shift = apply_cuts(objects,
                                                cmxdir=cmxdir,
                                                noqso=noqso)

        return _finalize_targets(objects, cmx_target, priority_shift)

    # Counter for number of bricks processed;
    # a numpy scalar allows updating nbrick in python 2
    # cf. https://www.python.org/dev/peps/pep-3104/
    nbrick = np.zeros((), dtype='i8')

    t0 = time()

    def _update_status(result):
        ''' wrapper function for the critical reduction operation,
            that occurs on the main parallel process '''
        if nbrick % 20 == 0 and nbrick > 0:
            elapsed = time() - t0
            rate = elapsed / nbrick
            log.info(
                '{} files; {:.1f} secs/file; {:.1f} total mins elapsed'.format(
                    nbrick, rate, elapsed / 60.))
        nbrick[...] += 1  # this is an in-place modification
        return result

    # -Parallel process input files
    if numproc > 1:
        pool = sharedmem.MapReduce(np=numproc)
        with pool:
            targets = pool.map(_select_targets_file,
                               infiles,
                               reduce=_update_status)
    else:
        targets = list()
        for x in infiles:
            targets.append(_update_status(_select_targets_file(x)))

    targets = np.concatenate(targets)

    return targets
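Usage sketch (hypothetical sweep file name); noqso=True skips the slow quasar cuts, as the docstring notes:

targets = select_targets(['sweep-000m005-010p000.fits'], numproc=1, noqso=True)
print(len(targets), 'targets passing cmx cuts')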
Example #40
from datetime import datetime

from pkg_resources import resource_filename
from time import time
from astropy.io import ascii
from glob import glob
import healpy as hp

from desitarget import io
from desitarget.internal import sharedmem
from desimodel.footprint import radec2pix
from desitarget.geomask import add_hp_neighbors, radec_match_to, nside2nside

# ADM set up the DESI default logger
from desiutil.log import get_logger
log = get_logger()

# ADM start the clock
start = time()

# ADM columns contained in our version of the Tycho fits files.
tychodatamodel = np.array([], dtype=[
    ('TYC1', '>i2'), ('TYC2', '>i2'), ('TYC3', '|u1'),
    ('RA', '>f8'), ('DEC', '>f8'),
    ('MEAN_RA', '>f8'), ('MEAN_DEC', '>f8'),
    ('SIGMA_RA', '>f4'), ('SIGMA_DEC', '>f4'),
    # ADM these are converted to be in mas/yr for consistency with Gaia.
    ('PM_RA', '>f4'), ('PM_DEC', '>f4'),
    ('SIGMA_PM_RA', '>f4'), ('SIGMA_PM_DEC', '>f4'),
    ('EPOCH_RA', '>f4'), ('EPOCH_DEC', '>f4'),
    ('MAG_BT', '>f4'), ('MAG_VT', '>f4'), ('MAG_HP', '>f4'), ('ISGALAXY', '|u1'),
Example #41
def add_missing_frames(frames):
    '''
    Adds blank FrameLite objects (with ivar=0) for any missing frames, with the
    correct shape to match those that do exist.

    Args:
        frames: dict of FrameLite objects, keyed by (night,expid,camera)

    Modifies `frames` in-place.

    Example: if `frames` has keys (2020,1,'b0') and (2020,1,'r0') but
    not (2020,1,'z0'), this will add a blank FrameLite object for z0.

    The purpose of this is to facilitate frames2spectra, which needs
    *something* for every spectro camera for every exposure that is included.
    '''

    log = get_logger()

    #- First figure out the number of wavelengths per band
    wave = dict()
    ndiag = dict()
    for (night, expid, camera), frame in frames.items():
        band = camera[0]
        if band not in wave:
            wave[band] = frame.wave
        if band not in ndiag:
            ndiag[band] = frame.rdat.shape[1]

    #- Now loop through all frames, filling in any missing bands
    bands = sorted(list(wave.keys()))
    for (night, expid, camera), frame in list(frames.items()):
        band = camera[0]
        spectro = camera[1:]
        for x in bands:
            if x == band:
                continue

            xcam = x + spectro
            if (night, expid, xcam) in frames:
                continue

            log.warning('Creating blank data for missing frame {}'.format(
                (night, expid, xcam)))
            nwave = len(wave[x])
            nspec = frame.flux.shape[0]
            flux = np.zeros((nspec, nwave), dtype='f4')
            ivar = np.zeros((nspec, nwave), dtype='f4')
            mask = np.zeros((nspec, nwave), dtype='u4') + specmask.NODATA
            rdat = np.zeros((nspec, ndiag[x], nwave), dtype='f4')

            #- Copy the header and correct the camera keyword
            header = fitsio.FITSHDR(frame.header)
            header['camera'] = xcam

            #- Make new blank scores, replacing trailing band _B/R/Z
            dtype = list()
            if frame.scores is not None:
                for name in frame.scores.dtype.names:
                    if name.endswith('_' + band.upper()):
                        xname = name[0:-1] + x.upper()
                        dtype.append((xname, type(frame.scores[name][0])))

                scores = np.zeros(nspec, dtype=dtype)
            else:
                scores = None

            #- Add the blank FrameLite object
            frames[(night, expid, xcam)] = FrameLite(wave[x], flux, ivar, mask,
                                                     rdat, frame.fibermap,
                                                     header, scores)
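Standalone sketch (plain dict stand-ins instead of FrameLite objects) of the key-completion logic above:

frames = {(20200101, 1, 'b0'): 'B-frame', (20200101, 1, 'r0'): 'R-frame'}
for (night, expid, camera) in list(frames.keys()):
    spectro = camera[1:]
    for band in ('b', 'r', 'z'):
        key = (night, expid, band + spectro)
        # fill any missing camera with a blank placeholder
        frames.setdefault(key, 'blank')
assert (20200101, 1, 'z0') in frames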
Example #42
from contextlib import contextmanager

@contextmanager
def stdouterr_redirected(to=None, comm=None):
    """
    Redirect stdout and stderr to a file.

    The general technique is based on:

    http://stackoverflow.com/questions/5081657
    http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/

    One difference here is that each process in the communicator
    redirects to a different temporary file, and upon exit from the
    context the rank zero process concatenates these, in order, into
    the final output file.

    Args:
        to (str): The output file name.
        comm (mpi4py.MPI.Comm): The optional MPI communicator.
    """

    # The currently active POSIX file descriptors
    fd_out = sys.stdout.fileno()
    fd_err = sys.stderr.fileno()

    # The DESI logger
    log = get_logger()

    def _redirect(out_to, err_to):

        # Flush the C-level buffers
        if c_stdout is not None:
            libc.fflush(c_stdout)
        if c_stderr is not None:
            libc.fflush(c_stderr)

        # This closes the python file handles, and marks the POSIX
        # file descriptors for garbage collection, UNLESS those
        # are the special file descriptors for stderr/stdout.
        sys.stdout.close()
        sys.stderr.close()

        # Close fd_out/fd_err if they are open, and copy the
        # input file descriptors to these.
        os.dup2(out_to, fd_out)
        os.dup2(err_to, fd_err)

        # Create a new sys.stdout / sys.stderr that points to the
        # redirected POSIX file descriptors.  In Python 3, these
        # are actually higher level IO objects.
        if sys.version_info[0] < 3:
            sys.stdout = os.fdopen(fd_out, "wb")
            sys.stderr = os.fdopen(fd_err, "wb")
        else:
            # Python 3 case
            sys.stdout = io.TextIOWrapper(os.fdopen(fd_out, 'wb'))
            sys.stderr = io.TextIOWrapper(os.fdopen(fd_err, 'wb'))

        # update DESI logging to use new stdout
        while len(log.handlers) > 0:
            h = log.handlers[0]
            log.removeHandler(h)
        # Add the current stdout.
        ch = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            "%(levelname)s:%(filename)s:%(lineno)s:%(funcName)s: %(message)s")
        ch.setFormatter(formatter)
        log.addHandler(ch)

    # redirect both stdout and stderr to the same file

    if to is None:
        to = "/dev/null"

    if (comm is None) or (comm.rank == 0):
        log.debug("Begin log redirection to {} at {}".format(
            to, time.asctime()))

    # Save the original file descriptors so we can restore them later
    saved_fd_out = os.dup(fd_out)
    saved_fd_err = os.dup(fd_err)

    try:
        pto = to
        if comm is not None:
            if to != "/dev/null":
                pto = "{}_{}".format(to, comm.rank)

        # open python file, which creates low-level POSIX file
        # descriptor.
        file = open(pto, "w")

        # redirect stdout/stderr to this new file descriptor.
        _redirect(out_to=file.fileno(), err_to=file.fileno())

        yield  # allow code to be run with the redirected output

        # close python file handle, which will mark POSIX file
        # descriptor for garbage collection.  That is fine since
        # we are about to overwrite those in the finally clause.
        file.close()

    finally:
        # restore old stdout and stderr
        _redirect(out_to=saved_fd_out, err_to=saved_fd_err)

        if comm is not None:
            # concatenate per-process files
            comm.barrier()
            if comm.rank == 0:
                with open(to, "w") as outfile:
                    for p in range(comm.size):
                        outfile.write(
                            "================= Process {} =================\n".
                            format(p))
                        fname = "{}_{}".format(to, p)
                        with open(fname) as infile:
                            outfile.write(infile.read())
                        os.remove(fname)
            comm.barrier()

        if (comm is None) or (comm.rank == 0):
            log.debug("End log redirection to {} at {}".format(
                to, time.asctime()))

        # flush python handles for good measure
        sys.stdout.flush()
        sys.stderr.flush()

    return
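Usage sketch (assumes the @contextmanager decoration above; `comm` is a hypothetical mpi4py communicator). Each rank writes to its own temporary file and rank 0 stitches them together on exit:

with stdouterr_redirected(to='pipeline.log', comm=comm):
    print('this line ends up in pipeline.log, tagged by process')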
Example #43
def write_frame(outfile, frame, header=None, fibermap=None, units=None):
    """Write a frame fits file and returns path to file written.

    Args:
        outfile: full path to output file, or tuple (night, expid, channel)
        frame:  desispec.frame.Frame object with wave, flux, ivar...

    Optional:
        header: astropy.io.fits.Header or dict to override frame.header
        fibermap: table to store as FIBERMAP HDU

    Returns:
        full filepath of output file that was written

    Note:
        to create a Frame object to pass into write_frame,
        frame = Frame(wave, flux, ivar, resolution_data)
    """
    log = get_logger()
    outfile = makepath(outfile, 'frame')

    #- Ignore some known and harmless units warnings
    import warnings
    warnings.filterwarnings(
        'ignore', message="'.*nanomaggies.* did not parse as fits unit.*")
    warnings.filterwarnings(
        'ignore', message=".*'10\*\*6 arcsec.* did not parse as fits unit.*")

    if header is not None:
        hdr = fitsheader(header)
    else:
        hdr = fitsheader(frame.meta)

    add_dependencies(hdr)

    # Vet the frame before writing
    diagnosis = frame.vet()
    if diagnosis != 0:
        raise IOError(
            "Frame did not pass simple vetting test. diagnosis={:d}".format(
                diagnosis))

    hdus = fits.HDUList()
    x = fits.PrimaryHDU(frame.flux.astype('f4'), header=hdr)
    x.header['EXTNAME'] = 'FLUX'
    if units is not None:
        units = str(units)
        if 'BUNIT' in hdr and hdr['BUNIT'] != units:
            log.warning('BUNIT {bunit} != units {units}; using {units}'.format(
                bunit=hdr['BUNIT'], units=units))
        x.header['BUNIT'] = units
    hdus.append(x)

    hdus.append(fits.ImageHDU(frame.ivar.astype('f4'), name='IVAR'))
    # hdus.append( fits.CompImageHDU(frame.mask, name='MASK') )
    hdus.append(fits.ImageHDU(frame.mask, name='MASK'))
    hdus.append(fits.ImageHDU(frame.wave.astype('f8'), name='WAVELENGTH'))
    hdus[-1].header['BUNIT'] = 'Angstrom'
    if frame.resolution_data is not None:
        hdus.append(
            fits.ImageHDU(frame.resolution_data.astype('f4'),
                          name='RESOLUTION'))
    elif frame.wsigma is not None:
        log.debug("Using ysigma from qproc")
        qrimg = fits.ImageHDU(frame.wsigma.astype('f4'), name='YSIGMA')
        qrimg.header["NDIAG"] = frame.ndiag
        hdus.append(qrimg)
    if fibermap is not None:
        fibermap = encode_table(fibermap)  #- unicode -> bytes
        fibermap.meta['EXTNAME'] = 'FIBERMAP'
        hdus.append(fits.convenience.table_to_hdu(fibermap))
    elif frame.fibermap is not None:
        fibermap = encode_table(frame.fibermap)  #- unicode -> bytes
        fibermap.meta['EXTNAME'] = 'FIBERMAP'
        hdus.append(fits.convenience.table_to_hdu(fibermap))
    elif frame.spectrograph is not None:
        x.header[
            'FIBERMIN'] = 500 * frame.spectrograph  # Hard-coded (as in desispec.frame)
    else:
        log.error(
            "You are likely writing a frame without sufficient fiber info")

    if frame.chi2pix is not None:
        hdus.append(fits.ImageHDU(frame.chi2pix.astype('f4'), name='CHI2PIX'))

    if frame.scores is not None:
        scores_tbl = encode_table(frame.scores)  #- unicode -> bytes
        scores_tbl.meta['EXTNAME'] = 'SCORES'
        hdus.append(fits.convenience.table_to_hdu(scores_tbl))
        if frame.scores_comments is not None:  # add comments in header
            hdu = hdus['SCORES']
            for i in range(1, 999):
                key = 'TTYPE' + str(i)
                if key in hdu.header:
                    value = hdu.header[key]
                    if value in frame.scores_comments.keys():
                        hdu.header[key] = (value, frame.scores_comments[value])

    hdus.writeto(outfile + '.tmp', overwrite=True, checksum=True)

    os.rename(outfile + '.tmp', outfile)

    return outfile
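Usage sketch (toy shapes; assumes numpy as np and the Frame class from this codebase are in scope, and that a frame without resolution data passes vetting):

wave = np.linspace(3600., 5800., 100)
flux = np.ones((5, 100), dtype='f4')
ivar = np.ones((5, 100), dtype='f4')
frame = Frame(wave, flux, ivar, spectrograph=0)
# path = write_frame('frame-b0-00000001.fits', frame)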
Example #44
def compute_fiberflat(frame,
                      nsig_clipping=10.,
                      accuracy=5.e-4,
                      minval=0.1,
                      maxval=10.,
                      max_iterations=100,
                      smoothing_res=5.,
                      max_bad=100,
                      max_rej_it=5,
                      min_sn=0,
                      diag_epsilon=1e-3):
    """Compute fiber flat by deriving an average spectrum and dividing all fiber data by this average.
    Input data are expected to be on the same wavelength grid, with uncorrelated noise.
    They do not, however, have exactly the same resolution.

    Args:
        frame (desispec.Frame): input Frame object with attributes
            wave, flux, ivar, resolution_data
        nsig_clipping : [optional] sigma clipping value for outlier rejection
        accuracy : [optional] accuracy of fiberflat (end test for the iterative loop)
        minval: [optional] mask pixels with flux < minval * median fiberflat.
        maxval: [optional] mask pixels with flux > maxval * median fiberflat.
        max_iterations: [optional] maximum number of iterations
        smoothing_res: [optional] spacing between spline fit nodes for smoothing the fiberflat
        max_bad: [optional] mask entire fiber if more than max_bad-1 initially unmasked pixels are masked during the iterations
        max_rej_it: [optional] reject at most the max_rej_it worst pixels in each iteration
        min_sn: [optional] mask portions with signal to noise less than min_sn
        diag_epsilon: [optional] size of the regularization term in the deconvolution


    Returns:
        desispec.FiberFlat object with attributes
            wave, fiberflat, ivar, mask, meanspec

    Notes:
    - we first iteratively :

       - compute a deconvolved mean spectrum
       - compute a fiber flat using the resolution convolved mean spectrum for each fiber
       - smooth the fiber flat along wavelength
       - clip outliers

    - then we compute a fiberflat at the native fiber resolution (not smoothed)

    - the routine returns the fiberflat, its inverse variance, mask, and the deconvolved mean spectrum

    - the fiberflat is the ratio data/mean, so the data should be divided by this flat

    NOTE THAT THIS CODE HAS NOT BEEN TESTED WITH ACTUAL FIBER TRANSMISSION VARIATIONS,
    OUTLIER PIXELS, DEAD COLUMNS ...
    """
    log = get_logger()
    log.info("starting")

    #
    # chi2 = sum_(fiber f) sum_(wavelength i) w_fi ( D_fi - F_fi (R_f M)_i )^2
    #
    # where
    # w = inverse variance
    # D = flux data (at the resolution of the fiber)
    # F = smooth fiber flat
    # R = resolution data
    # M = mean deconvolved spectrum
    #
    # M = A^{-1} B
    # with
    # A_kl = sum_(fiber f) sum_(wavelength i) w_fi F_fi^2 (R_fki R_fli)
    # B_k = sum_(fiber f) sum_(wavelength i) w_fi D_fi F_fi R_fki
    #
    # defining R'_fi = sqrt(w_fi) F_fi R_fi
    # and      D'_fi = sqrt(w_fi) D_fi
    #
    # A = sum_(fiber f) R'_f R'_f^T
    # B = sum_(fiber f) R'_f D'_f
    # (it's faster that way, and we try to use sparse matrices as much as possible)
    #

    #- Shortcuts
    nwave = frame.nwave
    nfibers = frame.nspec
    wave = frame.wave.copy()  #- this will become part of output too
    flux = frame.flux.copy()
    ivar = frame.ivar * (frame.mask == 0)

    # iterative fitting and clipping to get precise mean spectrum

    # we first need to iterate to converge on a solution of mean spectrum
    # and smooth fiber flat. several iterations are needed when
    # throughput AND resolution vary from fiber to fiber.
    # the end test is that the fiber flat has varied by less than 'accuracy'
    # with respect to the previous iteration at all wavelengths
    # we also have a max. number of iterations for this code

    nout_tot = 0
    chi2pdf = 0.

    smooth_fiberflat = np.ones((flux.shape))

    chi2 = np.zeros((flux.shape))

    ## mask low sn portions
    w = flux * np.sqrt(ivar) < min_sn
    ivar[w] = 0

    ## 0th pass: reject pixels according to minval and maxval
    mean_spectrum = np.zeros(flux.shape[1])
    nbad = np.zeros(nfibers, dtype=int)
    for iteration in range(max_iterations):
        for i in range(flux.shape[1]):
            w = ivar[:, i] > 0
            if w.sum() > 0:
                mean_spectrum[i] = np.median(flux[w, i])

        nbad_it = 0
        for fib in range(nfibers):
            w = ((flux[fib, :] < minval * mean_spectrum) |
                 (flux[fib, :] > maxval * mean_spectrum)) & (ivar[fib, :] > 0)
            nbad_it += w.sum()
            nbad[fib] += w.sum()

            if w.sum() > 0:
                ivar[fib, w] = 0
                log.warning("0th pass: masking {} pixels in fiber {}".format(
                    w.sum(), fib))
            if nbad[fib] >= max_bad:
                ivar[fib, :] = 0
                log.warning(
                    "0th pass: masking entire fiber {} (nbad={})".format(
                        fib, nbad[fib]))
        if nbad_it == 0:
            break

    # 1st pass: median for the spectrum, flat field without resolution,
    # with outlier rejection
    for iteration in range(max_iterations):

        # use median for spectrum
        mean_spectrum = np.zeros((flux.shape[1]))
        for i in range(flux.shape[1]):
            w = ivar[:, i] > 0
            if w.sum() > 0:
                mean_spectrum[i] = np.median(flux[w, i])

        nbad_it = 0
        sum_chi2 = 0
        # not more than max_rej_it pixels per fiber at a time
        for fib in range(nfibers):
            w = ivar[fib, :] > 0
            if w.sum() == 0:
                continue
            F = flux[fib, :] * 0
            w = (mean_spectrum != 0) & (ivar[fib, :] > 0)
            F[w] = flux[fib, w] / mean_spectrum[w]
            smooth_fiberflat[fib, :] = spline_fit(
                wave, wave[w], F[w], smoothing_res,
                ivar[fib, w] * mean_spectrum[w]**2)
            chi2 = ivar[fib, :] * (flux[fib, :] -
                                   mean_spectrum * smooth_fiberflat[fib, :])**2
            w = np.isnan(chi2)
            bad = np.where(chi2 > nsig_clipping**2)[0]
            if bad.size > 0:
                if bad.size > max_rej_it:  # not more than max_rej_it pixels at a time
                    ii = np.argsort(chi2[bad])
                    bad = bad[ii[-max_rej_it:]]
                ivar[fib, bad] = 0
                log.warning(
                    "1st pass: rejecting {} pixels from fiber {}".format(
                        len(bad), fib))
                nbad[fib] += len(bad)
                if nbad[fib] >= max_bad:
                    ivar[fib, :] = 0
                    log.warning(
                        "1st pass: rejecting fiber {} due to too many (new) bad pixels"
                        .format(fib))
                nbad_it += len(bad)

            sum_chi2 += chi2.sum()
        ndf = int((ivar > 0).sum() - nwave - nfibers * (nwave / smoothing_res))
        chi2pdf = 0.
        if ndf > 0:
            chi2pdf = sum_chi2 / ndf
        log.info(
            "1st pass iter #{} chi2={}/{} chi2pdf={} nout={} (nsig={})".format(
                iteration, sum_chi2, ndf, chi2pdf, nbad_it, nsig_clipping))

        if nbad_it == 0:
            break
    ## flatten fiberflat
    ## normalize smooth_fiberflat:
    mean = np.ones(smooth_fiberflat.shape[1])
    for i in range(smooth_fiberflat.shape[1]):
        w = ivar[:, i] > 0
        if w.sum() > 0:
            mean[i] = np.median(smooth_fiberflat[w, i])
    smooth_fiberflat = smooth_fiberflat / mean

    median_spectrum = mean_spectrum * 1.

    previous_smooth_fiberflat = smooth_fiberflat * 0
    log.info("after 1st pass : nout = %d/%d" %
             (np.sum(ivar == 0), np.size(ivar.flatten())))
    # 2nd pass is full solution including deconvolved spectrum, no outlier rejection
    for iteration in range(max_iterations):
        ## reset sum_chi2
        sum_chi2 = 0
        log.info("2nd pass, iter %d : mean deconvolved spectrum" % iteration)

        # fit mean spectrum
        A = scipy.sparse.lil_matrix((nwave, nwave)).tocsr()
        B = np.zeros((nwave))

        # diagonal sparse matrix with content = sqrt(ivar)*flat of a given fiber
        SD = scipy.sparse.lil_matrix((nwave, nwave))

        # this is to go a bit faster
        sqrtwflat = np.sqrt(ivar) * smooth_fiberflat

        # loop on fiber to handle resolution (this is long)
        for fiber in range(nfibers):
            if fiber % 10 == 0:
                log.info("2nd pass, filling matrix, iter %d fiber %d" %
                         (iteration, fiber))

            ### R = Resolution(resolution_data[fiber])
            R = frame.R[fiber]
            SD.setdiag(sqrtwflat[fiber])

            sqrtwflatR = SD * R  # each row r of R is multiplied by sqrtwflat[r]

            A = A + (sqrtwflatR.T * sqrtwflatR).tocsr()
            B += sqrtwflatR.T.dot(np.sqrt(ivar[fiber]) * flux[fiber])
        A_pos_def = A.todense()
        log.info("deconvolving")
        w = A.diagonal() > 0

        A_pos_def = A_pos_def[w, :]
        A_pos_def = A_pos_def[:, w]
        mean_spectrum = np.zeros(nwave)
        try:
            mean_spectrum[w] = cholesky_solve(A_pos_def, B[w])
        except Exception:
            log.info("cholesky failed, trying svd inverse in iter {}".format(
                iteration))
            mean_spectrum[w] = np.linalg.lstsq(A_pos_def, B[w], rcond=None)[0]

        for fiber in range(nfibers):

            if np.sum(ivar[fiber] > 0) == 0:
                continue

            ### R = Resolution(resolution_data[fiber])
            R = frame.R[fiber]

            M = R.dot(mean_spectrum)
            ok = (M != 0) & (ivar[fiber, :] > 0)
            if ok.sum() == 0:
                continue
            smooth_fiberflat[fiber] = spline_fit(
                wave, wave[ok], flux[fiber, ok] / M[ok], smoothing_res,
                ivar[fiber, ok] * M[ok]**2) * (ivar[fiber, :] * M**2 > 0)
            chi2 = ivar[fiber] * (flux[fiber] - smooth_fiberflat[fiber] * M)**2
            sum_chi2 += chi2.sum()
            w = np.isnan(smooth_fiberflat[fiber])
            if w.sum() > 0:
                ivar[fiber] = 0
                smooth_fiberflat[fiber] = 1

        # normalize to get a mean fiberflat=1
        mean = np.ones(smooth_fiberflat.shape[1])
        for i in range(nwave):
            w = ivar[:, i] > 0
            if w.sum() > 0:
                mean[i] = np.median(smooth_fiberflat[w, i])
        ok = np.where(mean != 0)[0]
        smooth_fiberflat[:, ok] /= mean[ok]

        # this is the max difference between two iterations
        max_diff = np.max(
            np.abs(smooth_fiberflat - previous_smooth_fiberflat) * (ivar > 0.))
        previous_smooth_fiberflat = smooth_fiberflat.copy()

        ndf = int(np.sum(ivar > 0) - nwave - nfibers * (nwave / smoothing_res))
        chi2pdf = 0.
        if ndf > 0:
            chi2pdf = sum_chi2 / ndf
        log.info("2nd pass, iter %d, chi2=%f ndf=%d chi2pdf=%f" %
                 (iteration, sum_chi2, ndf, chi2pdf))

        if max_diff < accuracy:
            break

        log.info(
            "2nd pass, iter %d, max diff. = %g > requirement = %g, continue iterating"
            % (iteration, max_diff, accuracy))

    log.info("Total number of masked pixels=%d" % nout_tot)
    log.info("3rd pass, final computation of fiber flat")

    # now use the mean spectrum to compute the flat field correction without any smoothing
    # because sharp features can arise from dead columns

    fiberflat = np.ones((flux.shape))
    fiberflat_ivar = np.zeros((flux.shape))
    mask = np.zeros((flux.shape), dtype='uint32')

    # reset ivar
    ivar = frame.ivar

    fiberflat_mask = 12  # place holder for actual mask bit when defined

    nsig_for_mask = nsig_clipping  # only mask out N sigma outliers

    for fiber in range(nfibers):

        if np.sum(ivar[fiber] > 0) == 0:
            continue

        ### R = Resolution(resolution_data[fiber])
        R = frame.R[fiber]
        M = np.array(np.dot(R.todense(), mean_spectrum)).flatten()
        fiberflat[fiber] = (M != 0) * flux[fiber] / (M + (M == 0)) + (M == 0)
        fiberflat_ivar[fiber] = ivar[fiber] * M**2
        nbad_tot = 0
        iteration = 0
        while iteration < 500:
            w = fiberflat_ivar[fiber, :] > 0
            if w.sum() < 100:
                break
            smooth_fiberflat = spline_fit(wave, wave[w], fiberflat[fiber, w],
                                          smoothing_res, fiberflat_ivar[fiber,
                                                                        w])
            chi2 = fiberflat_ivar[fiber] * (fiberflat[fiber] -
                                            smooth_fiberflat)**2
            bad = np.where(chi2 > nsig_for_mask**2)[0]
            if bad.size > 0:

                nbadmax = 1
                if bad.size > nbadmax:  # not more than nbadmax pixels at a time
                    ii = np.argsort(chi2[bad])
                    bad = bad[ii[-nbadmax:]]

                mask[fiber, bad] += fiberflat_mask
                fiberflat_ivar[fiber, bad] = 0.
                nbad_tot += bad.size
            else:
                break
            iteration += 1

        log.info("3rd pass : fiber #%d , number of iterations %d" %
                 (fiber, iteration))

    # set median flat to 1
    log.info("3rd pass : set median fiberflat to 1")

    mean = np.ones((flux.shape[1]))
    for i in range(flux.shape[1]):
        ok = np.where((mask[:, i] == 0) & (ivar[:, i] > 0))[0]
        if ok.size > 0:
            mean[i] = np.median(fiberflat[ok, i])
    ok = np.where(mean != 0)[0]
    for fiber in range(nfibers):
        fiberflat[fiber, ok] /= mean[ok]

    log.info("3rd pass : interpolating over masked pixels")

    for fiber in range(nfibers):

        if np.sum(ivar[fiber] > 0) == 0:
            continue
        # replace bad by smooth fiber flat
        bad = np.where((mask[fiber] > 0) | (fiberflat_ivar[fiber] == 0)
                       | (fiberflat[fiber] < minval)
                       | (fiberflat[fiber] > maxval))[0]

        if bad.size > 0:

            fiberflat_ivar[fiber, bad] = 0

            # find max length of segment with bad pix
            length = 0
            for i in range(bad.size):
                ib = bad[i]
                ilength = 1
                tmp = ib
                for jb in bad[i + 1:]:
                    if jb == tmp + 1:
                        ilength += 1
                        tmp = jb
                    else:
                        break
                length = max(length, ilength)
            if length > 10:
                log.info(
                    "3rd pass : fiber #%d has a max length of bad pixels=%d" %
                    (fiber, length))
            smoothing_res = float(max(100, length))
            x = np.arange(wave.size)

            ok = fiberflat_ivar[fiber] > 0
            if ok.sum() == 0:
                continue
            try:
                smooth_fiberflat = spline_fit(x, x[ok], fiberflat[fiber, ok],
                                              smoothing_res,
                                              fiberflat_ivar[fiber, ok])
                fiberflat[fiber, bad] = smooth_fiberflat[bad]
            except ValueError:
                # spline fit failed; fall back to a flat of 1 with zero weight
                fiberflat[fiber, bad] = 1
                fiberflat_ivar[fiber, bad] = 0

        if nbad_tot > 0:
            log.info(
                "3rd pass : fiber #%d masked pixels = %d (%d iterations)" %
                (fiber, nbad_tot, iteration))

    # set median flat to 1
    log.info("set median fiberflat to 1")

    mean = np.ones((flux.shape[1]))
    for i in range(flux.shape[1]):
        ok = np.where((mask[:, i] == 0) & (ivar[:, i] > 0))[0]
        if ok.size > 0:
            mean[i] = np.median(fiberflat[ok, i])
    ok = np.where(mean != 0)[0]
    for fiber in range(nfibers):
        fiberflat[fiber, ok] /= mean[ok]

    log.info("done fiberflat")

    return FiberFlat(wave,
                     fiberflat,
                     fiberflat_ivar,
                     mask,
                     mean_spectrum,
                     chi2pdf=chi2pdf)
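#- Hedged aside (not from the original source): the per-fiber normalization
#- above reduces to flux / (R @ mean_spectrum) with zero-protection. A tiny
#- numpy-only illustration with toy arrays; all names here are made up:

import numpy as np

toy_flux = np.array([[2.0, 4.0, 0.0]])   # one fiber, three wavelength bins
toy_M = np.array([2.0, 2.0, 0.0])        # resolution-convolved mean spectrum
toy_ivar = np.array([[1.0, 1.0, 0.0]])

# same expression as in the loop above: flat defaults to 1 where M == 0
toy_flat = (toy_M != 0) * toy_flux / (toy_M + (toy_M == 0)) + (toy_M == 0)
toy_flat_ivar = toy_ivar * toy_M**2      # var(flux/M) = var(flux) / M**2
# toy_flat -> [[1., 2., 1.]]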
Beispiel #45
0
    def __init__(self,
                 wave,
                 flux,
                 ivar,
                 mask=None,
                 resolution_data=None,
                 fibers=None,
                 spectrograph=None,
                 meta=None,
                 fibermap=None,
                 chi2pix=None,
                 scores=None,
                 scores_comments=None,
                 wsigma=None,
                 ndiag=21,
                 suppress_res_warning=False):
        """
        Lightweight wrapper for multiple spectra on a common wavelength grid

        x.wave, x.flux, x.ivar, x.mask, x.resolution_data, x.header, x.R

        Args:
            wave: 1D[nwave] wavelength in Angstroms
            flux: 2D[nspec, nwave] flux
            ivar: 2D[nspec, nwave] inverse variance of flux

        Optional:
            mask: 2D[nspec, nwave] integer bitmask of flux.  0=good.
            resolution_data: 3D[nspec, ndiag, nwave]
                             diagonals of resolution matrix data
            fibers: ndarray of which fibers these spectra are
            spectrograph: integer, which spectrograph [0-9]
            meta: dict-like object (e.g. FITS header)
            fibermap: fibermap table
            chi2pix: 2D[nspec, nwave] chi2 of 2D model to pixel-level data
                for pixels that contributed to each flux bin
            scores: dictionary of 1D arrays of size nspec
            scores_comments: dictionary of strings explaining the scores
            suppress_res_warning: bool to suppress the warning message when
                the resolution data is not read
        
        Parameters below allow on-the-fly resolution calculation:
            wsigma: 2D[nspec, nwave] sigma widths for each wavelength bin of all fibers
            ndiag (int): number of diagonals for the on-the-fly resolution matrix
        Notes:
            spectrograph input is used only if fibers is None.  In this case,
            it assumes nspec_per_spectrograph = flux.shape[0] and calculates
            the fibers array for this spectrograph, i.e.
            fibers = spectrograph * flux.shape[0] + np.arange(flux.shape[0])

        Attributes:
            All input args become object attributes.
            nspec : number of spectra, flux.shape[0]
            nwave : number of wavelengths, flux.shape[1]
            specmin : minimum fiber number
            R: array of sparse Resolution matrix objects converted
               from resolution_data
            fibermap: fibermap table if provided
        """
        assert wave.ndim == 1
        assert flux.ndim == 2
        assert wave.shape[0] == flux.shape[1]
        assert ivar.shape == flux.shape
        assert (mask is None) or mask.shape == flux.shape
        assert (mask is None) or mask.dtype in \
            (int, np.int64, np.int32, np.uint64, np.uint32), "Bad mask type "+str(mask.dtype)

        self.wave = wave
        self.flux = flux
        self.ivar = ivar
        self.meta = meta
        self.fibermap = fibermap
        self.nspec, self.nwave = self.flux.shape
        self.chi2pix = chi2pix
        self.scores = scores
        self.scores_comments = scores_comments
        self.ndiag = ndiag
        fibers_per_spectrograph = 500  #- hardcode; could get from desimodel

        if mask is None:
            self.mask = np.zeros(flux.shape, dtype=np.uint32)
        else:
            self.mask = util.mask32(mask)

        if resolution_data is not None:
            if resolution_data.ndim != 3 or \
               resolution_data.shape[0] != self.nspec or \
               resolution_data.shape[2] != self.nwave:
                raise ValueError(
                    "Wrong dimensions for resolution_data[nspec, ndiag, nwave]"
                )

        #- Maybe set up an identity resolution matrix instead of None?
        self.wsigma = wsigma
        self.resolution_data = resolution_data
        if resolution_data is not None:
            self.wsigma = None  #ignore width coefficients if resolution data is given explicitly
            self.ndiag = None
            self.R = np.array([Resolution(r) for r in resolution_data])
        elif wsigma is not None:
            from desispec.quicklook.qlresolution import QuickResolution
            assert ndiag is not None
            r = []
            for sigma in wsigma:
                r.append(QuickResolution(sigma=sigma, ndiag=self.ndiag))
            self.R = np.array(r)
        else:
            #- SK: arguably this should be an error, but the existing tests
            #- allow Frame objects without resolution data, so the ValueError
            #- was downgraded to a simple warning message.
            if not suppress_res_warning:
                log = get_logger()
                log.warning("Frame object is constructed without resolution data or respective "\
                        "sigma widths. Resolution will not be available")
            # raise ValueError("Need either resolution_data or coefficients to generate it")
        self.spectrograph = spectrograph

        # Deal with Fibers (these must be set!)
        if fibers is not None:
            fibers = np.asarray(fibers)
            if len(fibers) != self.nspec:
                raise ValueError("len(fibers) != nspec ({} != {})".format(
                    len(fibers), self.nspec))
            if fibermap is not None and np.any(fibers != fibermap['FIBER']):
                raise ValueError("fibermap doesn't match fibers")
            if (spectrograph is not None):
                minfiber = spectrograph * fibers_per_spectrograph
                maxfiber = (spectrograph + 1) * fibers_per_spectrograph
                if np.any(fibers < minfiber) or np.any(maxfiber <= fibers):
                    raise ValueError('fibers inconsistent with spectrograph')
            self.fibers = fibers
        else:
            if fibermap is not None:
                self.fibers = fibermap['FIBER']
            elif spectrograph is not None:
                self.fibers = spectrograph * fibers_per_spectrograph + np.arange(
                    self.nspec, dtype=int)
            elif (self.meta is not None) and ('FIBERMIN' in self.meta):
                self.fibers = self.meta['FIBERMIN'] + np.arange(self.nspec,
                                                                dtype=int)
            else:
                raise ValueError("Must set fibers by one of the methods!")

        if self.meta is not None:
            self.meta['FIBERMIN'] = np.min(self.fibers)
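#- Hedged usage sketch (not from the original source): constructing a minimal
#- spectra container with this __init__, assuming it belongs to
#- desispec.frame.Frame; resolution data is omitted, so the resolution
#- warning is explicitly suppressed.

# import numpy as np
# from desispec.frame import Frame   # assumed import path for this class
#
# nspec, nwave = 2, 5
# wave = np.linspace(3600., 3604., nwave)
# flux = np.ones((nspec, nwave))
# ivar = np.ones((nspec, nwave))
# frame = Frame(wave, flux, ivar, fibers=np.arange(nspec),
#               suppress_res_warning=True)
# assert (frame.nspec, frame.nwave) == (nspec, nwave)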
Beispiel #46
0
def qa_fiberflat(param, frame, fiberflat):
    """ Calculate QA on FiberFlat object

    Args:
        param: dict of QA parameters
        frame: Frame
        fiberflat: FiberFlat

    Returns:
        qadict: dict of QA outputs
          Need to record simple Python objects for yaml (str, float, int)
    """
    from desimodel.focalplane import fiber_area_arcsec2
    log = get_logger()

    # x, y, area
    fibermap = frame.fibermap
    x = fibermap['X_TARGET']
    y = fibermap['Y_TARGET']
    area = fiber_area_arcsec2(x, y)
    mean_area = np.mean(area)
    norm_area = area / mean_area
    npix = fiberflat.fiberflat.shape[1]

    # Normalize
    norm_flat = fiberflat.fiberflat / np.outer(norm_area, np.ones(npix))

    # Output dict
    qadict = {}

    # Check amplitude of the meanspectrum
    qadict['MAX_MEANSPEC'] = float(np.max(fiberflat.meanspec))
    if qadict['MAX_MEANSPEC'] < 100000:
        log.warning("Low counts in meanspec = {:g}".format(
            qadict['MAX_MEANSPEC']))

    # Record chi2pdf
    try:
        qadict['CHI2PDF'] = float(fiberflat.chi2pdf)
    except TypeError:
        qadict['CHI2PDF'] = 0.

    # N mask
    qadict['N_MASK'] = int(np.sum(fiberflat.mask > 0))
    if qadict['N_MASK'] > param['MAX_N_MASK']:  # Arbitrary
        log.warning("High rejection rate: {:d}".format(qadict['N_MASK']))

    # Scale (search for low/high throughput)
    gdp = fiberflat.mask == 0
    rtio = (frame.flux / np.outer(norm_area, np.ones(npix))) / np.outer(
        np.ones(fiberflat.nspec), fiberflat.meanspec)
    scale = np.median(rtio * gdp, axis=1)
    MAX_SCALE_OFF = float(np.max(np.abs(scale - 1.)))
    fiber = int(np.argmax(np.abs(scale - 1.)))
    qadict['MAX_SCALE_OFF'] = [MAX_SCALE_OFF, fiber]
    if qadict['MAX_SCALE_OFF'][0] > param['MAX_SCALE_OFF']:
        log.warning("Discrepant flux in fiberflat: {:g}, {:d}".format(
            qadict['MAX_SCALE_OFF'][0], qadict['MAX_SCALE_OFF'][1]))

    # Offset in fiberflat
    qadict['MAX_OFF'] = float(np.max(np.abs(norm_flat - 1.)))
    if qadict['MAX_OFF'] > param['MAX_OFF']:
        log.warning("Large offset in fiberflat: {:g}".format(
            qadict['MAX_OFF']))

    # Offset in mean of fiberflat
    #mean = np.mean(fiberflat.fiberflat*gdp,axis=1)
    mean = np.mean(norm_flat * gdp, axis=1)
    fiber = int(np.argmax(np.abs(mean - 1.)))
    qadict['MAX_MEAN_OFF'] = [float(np.max(np.abs(mean - 1.))), fiber]
    if qadict['MAX_MEAN_OFF'][0] > param['MAX_MEAN_OFF']:
        log.warning("Discrepant mean in fiberflat: {:g}, {:d}".format(
            qadict['MAX_MEAN_OFF'][0], qadict['MAX_MEAN_OFF'][1]))

    # RMS in individual fibers
    rms = np.std(gdp * (norm_flat - np.outer(mean, np.ones(fiberflat.nwave))),
                 axis=1)
    #rms = np.std(gdp*(fiberflat.fiberflat-
    #                  np.outer(mean, np.ones(fiberflat.nwave))),axis=1)
    fiber = int(np.argmax(rms))
    qadict['MAX_RMS'] = [float(np.max(rms)), fiber]
    if qadict['MAX_RMS'][0] > param['MAX_RMS']:
        log.warning("Large RMS in fiberflat: {:g}, {:d}".format(
            qadict['MAX_RMS'][0], qadict['MAX_RMS'][1]))

    # Return
    return qadict
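#- Hedged usage sketch: the `param` thresholds that qa_fiberflat reads,
#- inferred from the keys accessed above; the numeric values below are
#- illustrative placeholders, not official defaults.

# param = {
#     'MAX_N_MASK': 20000,     # maximum number of masked fiberflat pixels
#     'MAX_SCALE_OFF': 0.05,   # maximum fractional throughput offset
#     'MAX_OFF': 0.15,         # maximum |fiberflat - 1|
#     'MAX_MEAN_OFF': 0.05,    # maximum per-fiber |mean(fiberflat) - 1|
#     'MAX_RMS': 0.02,         # maximum per-fiber RMS of the fiberflat
# }
# qadict = qa_fiberflat(param, frame, fiberflat)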
Beispiel #47
0
def get_fiberbitmasked_frame_arrays(frame,
                                    bitmask=None,
                                    ivar_framemask=True,
                                    return_mask=False):
    """
    Function that takes a frame object and a bitmask and
    returns ivar (and optionally mask) array(s) that have fibers with
    offending bits in fibermap['FIBERSTATUS'] set to
    0 in ivar and optionally flips a bit in mask.

    input:
        frame: frame object
        bitmask: int32 or list/array of int32's derived from desispec.maskbits.fibermask
                 OR string indicating a keyword for get_fiberbitmask_comparison_value()
        ivar_framemask: bool (default=True), tells code whether to multiply the output
                 variance by (frame.mask==0)
        return_mask: bool, (default=False). Returns the frame.mask with the logic of
                 FIBERSTATUS applied.

    output:
        ivar: frame.ivar where the fibers with FIBERSTATUS & bitmask > 0
              set to zero ivar
        mask: (optional) frame.mask logically OR'ed with BADFIBER bit in cases with
              a bad FIBERSTATUS

    example bitmask list:
        bitmask = [fmsk.BROKENFIBER,fmsk.UNASSIGNED,fmsk.BADFIBER,\
                    fmsk.BADTRACE,fmsk.MANYBADCOL, fmsk.MANYREJECTED]
        bitmask = get_fiberbitmask_comparison_value(kind='fluxcalib')
        bitmask = 'fluxcalib'
        bitmask = 4128780
    """
    ivar = frame.ivar.copy()
    mask = frame.mask.copy()
    if ivar_framemask and frame.mask is not None:
        ivar *= (frame.mask == 0)

    if frame.fibermap is None:
        log = get_logger()
        log.warning("No fibermap was given, so no FIBERSTATUS check applied.")
    else:
        fmap = Table(frame.fibermap)

    if bitmask is None or frame.fibermap is None:
        if return_mask:
            return ivar, mask
        else:
            return ivar

    if isinstance(bitmask, (int, np.integer)):
        bad = bitmask
    elif isinstance(bitmask, str):
        if bitmask.isnumeric():
            bad = np.int32(bitmask)
        else:
            bad = get_fiberbitmask_comparison_value(kind=bitmask)
    else:
        bad = bitmask[0]
        for bit in bitmask[1:]:
            bad |= bit

    # find if any fibers have an intersection with the bad bits
    badfibers = fmap['FIBER'][(fmap['FIBERSTATUS'] & bad) > 0].data
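    # fiber numbers are global (0-4999); reduce to the row index within this spectrograph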
    badfibers = badfibers % 500
    # For the bad fibers, loop through and nullify them
    for fiber in badfibers:
        mask[fiber] |= specmask.BADFIBER
        if ivar_framemask:
            ivar[fiber] = 0.

    if return_mask:
        return ivar, mask
    else:
        return ivar
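#- Hedged usage sketch: nullify fibers whose FIBERSTATUS matches the
#- 'fluxcalib' bit selection, optionally retrieving the updated mask too.

# ivar = get_fiberbitmasked_frame_arrays(frame, bitmask='fluxcalib')
# ivar, mask = get_fiberbitmasked_frame_arrays(frame, bitmask='fluxcalib',
#                                              return_mask=True)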
Beispiel #48
0
def average_fiberflat(fiberflats):
    """Average several fiberflats 
    Args:
        fiberflats : list of `desispec.FiberFlat` object

    returns a desispec.FiberFlat object
    """

    log = get_logger()
    log.info("starting")

    if len(fiberflats) == 0:
        message = "input fiberflat list is empty"
        log.critical(message)
        raise ValueError(message)
    if len(fiberflats) == 1:
        log.warning("only one fiberflat to average??")
        return fiberflats[0]

    # check wavelength range
    for fflat in fiberflats[1:]:
        if not np.allclose(fiberflats[0].wave, fflat.wave):
            message = "fiberflats do not have the same wavelength arrays"
            log.critical(message)
            raise ValueError(message)
    wave = fiberflats[0].wave

    fiberflat = None
    ivar = None
    if len(fiberflats) > 2:
        log.info("{} fiberflat to average, use masked median".format(
            len(fiberflats)))
        tmp_fflat = []
        tmp_ivar = []
        tmp_mask = []
        for tmp in fiberflats:
            tmp_fflat.append(tmp.fiberflat)
            tmp_ivar.append(tmp.ivar)
            tmp_mask.append(tmp.mask)
        fiberflat = masked_median(np.array(tmp_fflat), np.array(tmp_mask))
        ivar = np.sum(np.array(tmp_ivar), axis=0)
        ivar *= 2. / np.pi  # penalty for using a median instead of a mean
    else:
        log.info("{} fiberflat to average, use weighted mean".format(
            len(fiberflats)))
        sw = None
        swf = None
        for tmp in fiberflats:
            w = (tmp.ivar) * (tmp.mask == 0)
            if sw is None:
                sw = w
                swf = w * tmp.fiberflat
            else:
                sw += w
                swf += w * tmp.fiberflat
        fiberflat = swf / (sw + (sw == 0))
        ivar = sw

    # combined mask: a pixel keeps mask=0 if at least one input fiberflat
    # has mask=0 there; otherwise the set bits are OR'ed together
    mask = None
    for tmp in fiberflats:
        if mask is None:
            mask = tmp.mask.copy()
        else:
            ii = (mask > 0) & (tmp.mask > 0)
            mask[ii] |= tmp.mask[ii]
            mask[tmp.mask == 0] = 0

    return FiberFlat(wave,
                     fiberflat,
                     ivar,
                     mask,
                     header=fiberflats[0].header,
                     fibers=fiberflats[0].fibers,
                     spectrograph=fiberflats[0].spectrograph)
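#- Hedged aside on the 2/pi factor above: for N Gaussian measurements the
#- median has variance ~ (pi/2) * sigma**2 / N, i.e. it is less efficient
#- than the mean by pi/2, so the summed ivar is scaled down by 2/pi.
#- A quick numeric check of that factor:

import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(size=(100000, 15))
var_mean = samples.mean(axis=1).var()             # ~ 1/15
var_median = np.median(samples, axis=1).var()     # ~ (pi/2)/15 for large odd N
# var_median / var_mean is close to pi/2 ~ 1.57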
Beispiel #49
0
def main(args):

    # Set up the logger
    if args.verbose:
        log = get_logger(DEBUG)
    else:
        log = get_logger()

    # Make sure all necessary environment variables are set
    DESI_SPECTRO_REDUX_DIR = "./quickGen"

    if 'DESI_SPECTRO_REDUX' not in os.environ:

        log.info('DESI_SPECTRO_REDUX environment is not set.')

    else:
        DESI_SPECTRO_REDUX_DIR = os.environ['DESI_SPECTRO_REDUX']

    if os.path.exists(DESI_SPECTRO_REDUX_DIR):
        if not os.path.isdir(DESI_SPECTRO_REDUX_DIR):
            raise RuntimeError("Path %s is not a directory" %
                               DESI_SPECTRO_REDUX_DIR)
    else:
        os.makedirs(DESI_SPECTRO_REDUX_DIR)

    SPECPROD_DIR = 'specprod'
    if 'SPECPROD' not in os.environ:
        log.info('SPECPROD environment is not set.')
    else:
        SPECPROD_DIR = os.environ['SPECPROD']
    prod_Dir = specprod_root()

    if os.path.exists(prod_Dir):
        if not os.path.isdir(prod_Dir):
            raise RuntimeError("Path %s is not a directory" % prod_Dir)
    else:
        os.makedirs(prod_Dir)

    # Initialize random number generator to use.
    np.random.seed(args.seed)
    random_state = np.random.RandomState(args.seed)

    # Derive spectrograph number from nstart if needed
    if args.spectrograph is None:
        # integer division: spectrograph number must be an int in Python 3
        args.spectrograph = args.nstart // args.n_fibers

    # Read fibermapfile to get object type, night and expid
    if args.fibermap:
        log.info("Reading fibermap file {}".format(args.fibermap))
        fibermap = read_fibermap(args.fibermap)
        objtype = get_source_types(fibermap)
        stdindx = np.where(objtype == 'STD')  # match STD with STAR
        mwsindx = np.where(objtype == 'MWS_STAR')  # match MWS_STAR with STAR
        bgsindx = np.where(objtype == 'BGS')  # match BGS with LRG
        objtype[stdindx] = 'STAR'
        objtype[mwsindx] = 'STAR'
        objtype[bgsindx] = 'LRG'
        NIGHT = fibermap.meta['NIGHT']
        EXPID = fibermap.meta['EXPID']
        TILEID = fibermap.meta['TILEID']
    else:
        # Create a blank fake fibermap
        fibermap = empty_fibermap(args.nspec)
        targetids = random_state.randint(2**62, size=args.nspec)
        fibermap['TARGETID'] = targetids
        # match the NIGHT/EXPID names used later when writing output files
        NIGHT = get_night()
        EXPID = 0

    log.info("Initializing SpecSim with config {}".format(args.config))
    desiparams = load_desiparams(config=args.config, telescope=args.telescope)
    qsim = get_simulator(args.config, num_fibers=1, params=desiparams)

    if args.simspec:
        # Read the input file
        log.info('Reading input file {}'.format(args.simspec))
        simspec = desisim.io.read_simspec(args.simspec)
        nspec = simspec.nspec
        if simspec.flavor == 'arc':
            log.warning("quickgen doesn't generate flavor=arc outputs")
            return
        else:
            wavelengths = simspec.wave
            spectra = simspec.flux
        if nspec < args.nspec:
            log.info("Only {} spectra in input file".format(nspec))
            args.nspec = nspec

    else:
        # Initialize the output truth table.
        spectra = []
        wavelengths = qsim.source.wavelength_out.to(u.Angstrom).value
        npix = len(wavelengths)
        truth = dict()
        meta = Table()
        truth['OBJTYPE'] = np.zeros(args.nspec, dtype=(str, 10))
        truth['FLUX'] = np.zeros((args.nspec, npix))
        truth['WAVE'] = wavelengths
        jj = list()

        for thisobj in set(true_objtype):
            ii = np.where(true_objtype == thisobj)[0]
            nobj = len(ii)
            truth['OBJTYPE'][ii] = thisobj
            log.info('Generating {} template'.format(thisobj))

            # Generate the templates
            if thisobj == 'ELG':
                elg = desisim.templates.ELG(wave=wavelengths,
                                            add_SNeIa=args.add_SNeIa)
                flux, tmpwave, meta1 = elg.make_templates(
                    nmodel=nobj,
                    seed=args.seed,
                    zrange=args.zrange_elg,
                    sne_rfluxratiorange=args.sne_rfluxratiorange)
            elif thisobj == 'LRG':
                lrg = desisim.templates.LRG(wave=wavelengths,
                                            add_SNeIa=args.add_SNeIa)
                flux, tmpwave, meta1 = lrg.make_templates(
                    nmodel=nobj,
                    seed=args.seed,
                    zrange=args.zrange_lrg,
                    sne_rfluxratiorange=args.sne_rfluxratiorange)
            elif thisobj == 'QSO':
                qso = desisim.templates.QSO(wave=wavelengths)
                flux, tmpwave, meta1 = qso.make_templates(
                    nmodel=nobj, seed=args.seed, zrange=args.zrange_qso)
            elif thisobj == 'BGS':
                bgs = desisim.templates.BGS(wave=wavelengths,
                                            add_SNeIa=args.add_SNeIa)
                flux, tmpwave, meta1 = bgs.make_templates(
                    nmodel=nobj,
                    seed=args.seed,
                    zrange=args.zrange_bgs,
                    rmagrange=args.rmagrange_bgs,
                    sne_rfluxratiorange=args.sne_rfluxratiorange)
            elif thisobj == 'STD':
                std = desisim.templates.STD(wave=wavelengths)
                flux, tmpwave, meta1 = std.make_templates(nmodel=nobj,
                                                          seed=args.seed)
            elif thisobj == 'QSO_BAD':  # use STAR template no color cuts
                star = desisim.templates.STAR(wave=wavelengths)
                flux, tmpwave, meta1 = star.make_templates(nmodel=nobj,
                                                           seed=args.seed)
            elif thisobj == 'MWS_STAR' or thisobj == 'MWS':
                mwsstar = desisim.templates.MWS_STAR(wave=wavelengths)
                flux, tmpwave, meta1 = mwsstar.make_templates(nmodel=nobj,
                                                              seed=args.seed)
            elif thisobj == 'WD':
                wd = desisim.templates.WD(wave=wavelengths)
                flux, tmpwave, meta1 = wd.make_templates(nmodel=nobj,
                                                         seed=args.seed)
            elif thisobj == 'SKY':
                flux = np.zeros((nobj, npix))
                meta1 = Table(dict(REDSHIFT=np.zeros(nobj, dtype=np.float32)))
            elif thisobj == 'TEST':
                flux = np.zeros((args.nspec, npix))
                indx = np.where(wavelengths > 5800.0 - 1E-6)[0][0]
                ref_integrated_flux = 1E-10
                ref_cst_flux_density = 1E-17
                single_line = (np.arange(args.nspec) % 2 == 0).astype(
                    np.float32)
                continuum = (np.arange(args.nspec) % 2 == 1).astype(np.float32)

                for spec in range(args.nspec):
                    flux[spec, indx] = single_line[
                        spec] * ref_integrated_flux / np.gradient(wavelengths)[
                            indx]  # single line
                    flux[spec] += continuum[
                        spec] * ref_cst_flux_density  # flat continuum

                meta1 = Table(
                    dict(REDSHIFT=np.zeros(args.nspec, dtype=np.float32),
                         LINE=wavelengths[indx] *
                         np.ones(args.nspec, dtype=np.float32),
                         LINEFLUX=single_line * ref_integrated_flux,
                         CONSTFLUXDENSITY=continuum * ref_cst_flux_density))
            else:
                log.fatal('Unknown object type {}'.format(thisobj))
                sys.exit(1)

            # Pack it in.
            truth['FLUX'][ii] = flux
            meta = vstack([meta, meta1])
            jj.append(ii.tolist())

            # Sanity check on units; templates currently return ergs, not 1e-17 ergs...
            # assert (thisobj == 'SKY') or (np.max(truth['FLUX']) < 1e-6)

        # Sort the metadata table.
        jj = sum(jj, [])
        meta_new = Table()
        for k in range(args.nspec):
            index = int(np.where(np.array(jj) == k)[0])
            meta_new = vstack([meta_new, meta[index]])
        meta = meta_new

        # Add TARGETID and the true OBJTYPE to the metadata table.
        meta.add_column(
            Column(true_objtype, dtype=(str, 10), name='TRUE_OBJTYPE'))
        meta.add_column(Column(targetids, name='TARGETID'))

        # Rename REDSHIFT -> TRUEZ anticipating later table joins with zbest.Z
        meta.rename_column('REDSHIFT', 'TRUEZ')

    # explicitly set location on focal plane if needed to support airmass
    # variations when using specsim v0.5
    if qsim.source.focal_xy is None:
        qsim.source.focal_xy = (u.Quantity(0, 'mm'), u.Quantity(100, 'mm'))

    # Set simulation parameters from the simspec header or desiparams
    bright_objects = ['bgs', 'mws', 'bright', 'BGS', 'MWS', 'BRIGHT_MIX']
    gray_objects = ['gray', 'grey']
    if args.simspec is None:
        object_type = objtype
        flavor = None
    elif simspec.flavor == 'science':
        object_type = None
        flavor = simspec.header['PROGRAM']
    else:
        object_type = None
        flavor = simspec.flavor
        log.warning(
            'Maybe using an outdated simspec file with flavor={}'.format(
                flavor))

    # Set airmass
    if args.airmass is not None:
        qsim.atmosphere.airmass = args.airmass
    elif args.simspec and 'AIRMASS' in simspec.header:
        qsim.atmosphere.airmass = simspec.header['AIRMASS']
    else:
        qsim.atmosphere.airmass = 1.25  # Science Req. Doc L3.3.2

    # Set site location
    if args.location is not None:
        qsim.observation.observatory = args.location
    else:
        qsim.observation.observatory = 'APO'

    # Set exptime
    if args.exptime is not None:
        qsim.observation.exposure_time = args.exptime * u.s
    elif args.simspec and 'EXPTIME' in simspec.header:
        qsim.observation.exposure_time = simspec.header['EXPTIME'] * u.s
    elif objtype in bright_objects:
        qsim.observation.exposure_time = desiparams['exptime_bright'] * u.s
    else:
        qsim.observation.exposure_time = desiparams['exptime_dark'] * u.s

    # Set Moon Phase
    if args.moon_phase is not None:
        qsim.atmosphere.moon.moon_phase = args.moon_phase
    elif args.simspec and 'MOONFRAC' in simspec.header:
        qsim.atmosphere.moon.moon_phase = simspec.header['MOONFRAC']
    elif flavor in bright_objects or object_type in bright_objects:
        qsim.atmosphere.moon.moon_phase = 0.7
    elif flavor in gray_objects:
        qsim.atmosphere.moon.moon_phase = 0.1
    else:
        qsim.atmosphere.moon.moon_phase = 0.5

    # Set Moon Zenith
    if args.moon_zenith is not None:
        qsim.atmosphere.moon.moon_zenith = args.moon_zenith * u.deg
    elif args.simspec and 'MOONALT' in simspec.header:
        qsim.atmosphere.moon.moon_zenith = simspec.header['MOONALT'] * u.deg
    elif flavor in bright_objects or object_type in bright_objects:
        qsim.atmosphere.moon.moon_zenith = 30 * u.deg
    elif flavor in gray_objects:
        qsim.atmosphere.moon.moon_zenith = 80 * u.deg
    else:
        qsim.atmosphere.moon.moon_zenith = 100 * u.deg

    # Set Moon - Object Angle
    if args.moon_angle is not None:
        qsim.atmosphere.moon.separation_angle = args.moon_angle * u.deg
    elif args.simspec and 'MOONSEP' in simspec.header:
        qsim.atmosphere.moon.separation_angle = simspec.header[
            'MOONSEP'] * u.deg
    elif flavor in bright_objects or object_type in bright_objects:
        qsim.atmosphere.moon.separation_angle = 50 * u.deg
    elif flavor in gray_objects:
        qsim.atmosphere.moon.separation_angle = 60 * u.deg
    else:
        qsim.atmosphere.moon.separation_angle = 60 * u.deg

    # Initialize per-camera output arrays that will be saved
    waves, trueflux, noisyflux, obsivar, resolution, sflux = {}, {}, {}, {}, {}, {}

    maxbin = 0
    nmax = args.nspec
    for camera in qsim.instrument.cameras:
        # Lookup this camera's resolution matrix and convert to the sparse
        # format used in desispec.
        R = Resolution(camera.get_output_resolution_matrix())
        resolution[camera.name] = np.tile(R.to_fits_array(),
                                          [args.nspec, 1, 1])
        waves[camera.name] = (camera.output_wavelength.to(
            u.Angstrom).value.astype(np.float32))
        nwave = len(waves[camera.name])
        maxbin = max(maxbin, len(waves[camera.name]))
        nobj = np.zeros((nmax, 3, maxbin))  # object photons
        nsky = np.zeros((nmax, 3, maxbin))  # sky photons
        nivar = np.zeros((nmax, 3, maxbin))  # inverse variance (object+sky)
        cframe_observedflux = np.zeros(
            (nmax, 3, maxbin))  # calibrated object flux
        cframe_ivar = np.zeros(
            (nmax, 3, maxbin))  # inverse variance of calibrated object flux
        cframe_rand_noise = np.zeros(
            (nmax, 3, maxbin))  # random Gaussian noise to calibrated flux
        sky_ivar = np.zeros((nmax, 3, maxbin))  # inverse variance of sky
        sky_rand_noise = np.zeros(
            (nmax, 3, maxbin))  # random Gaussian noise to sky only
        frame_rand_noise = np.zeros(
            (nmax, 3, maxbin))  # random Gaussian noise to nobj+nsky
        trueflux[camera.name] = np.empty(
            (args.nspec, nwave))  # calibrated flux
        noisyflux[camera.name] = np.empty(
            (args.nspec, nwave))  # observed flux with noise
        obsivar[camera.name] = np.empty(
            (args.nspec, nwave))  # inverse variance of flux
        if args.simspec:
            for i in range(10):
                cn = camera.name + str(i)
                if cn in simspec.cameras:
                    dw = np.gradient(simspec.cameras[cn].wave)
                    break
            else:
                raise RuntimeError(
                    'Unable to find a {} camera in input simspec'.format(
                        camera))
        else:
            sflux = np.empty((args.nspec, npix))

    #- Check if input simspec is for a continuum flat lamp instead of science
    #- This does not convolve to per-fiber resolution
    if args.simspec:
        if simspec.flavor == 'flat':
            log.info("Simulating flat lamp exposure")
            for i, camera in enumerate(qsim.instrument.cameras):
                channel = camera.name  #- from simspec, b/r/z not b0/r1/z9
                assert camera.output_wavelength.unit == u.Angstrom
                num_pixels = len(waves[channel])

                phot = list()
                for j in range(10):
                    cn = camera.name + str(j)
                    if cn in simspec.cameras:
                        camwave = simspec.cameras[cn].wave
                        dw = np.gradient(camwave)
                        phot.append(simspec.cameras[cn].phot)

                if len(phot) == 0:
                    raise RuntimeError(
                        'Unable to find a {} camera in input simspec'.format(
                            camera))
                else:
                    phot = np.vstack(phot)

                meanspec = resample_flux(waves[channel], camwave,
                                         np.average(phot / dw, axis=0))

                fiberflat = random_state.normal(loc=1.0,
                                                scale=1.0 / np.sqrt(meanspec),
                                                size=(nspec, num_pixels))
                ivar = np.tile(meanspec, [nspec, 1])
                mask = np.zeros((simspec.nspec, num_pixels), dtype=np.uint32)

                for kk in range((args.nspec + args.nstart - 1) //
                                args.n_fibers + 1):
                    camera = channel + str(kk)
                    outfile = desispec.io.findfile('fiberflat', NIGHT, EXPID,
                                                   camera)
                    start = max(args.n_fibers * kk, args.nstart)
                    end = min(args.n_fibers * (kk + 1), nmax)

                    if (args.spectrograph <= kk):
                        log.info(
                            "Writing files for channel:{}, spectrograph:{}, spectra:{} to {}"
                            .format(channel, kk, start, end))

                    ff = FiberFlat(waves[channel],
                                   fiberflat[start:end, :],
                                   ivar[start:end, :],
                                   mask[start:end, :],
                                   meanspec,
                                   header=dict(CAMERA=camera))
                    write_fiberflat(outfile, ff)
                    filePath = desispec.io.findfile("fiberflat", NIGHT, EXPID,
                                                    camera)
                    log.info("Wrote file {}".format(filePath))

            sys.exit(0)

    # Repeat the simulation for all spectra
    fluxunits = 1e-17 * u.erg / (u.s * u.cm**2 * u.Angstrom)
    for j in range(args.nspec):

        thisobjtype = objtype[j]
        sys.stdout.flush()
        if flavor == 'arc':
            qsim.source.update_in('Quickgen source {0}'.format(j), 'perfect',
                                  wavelengths * u.Angstrom,
                                  spectra * fluxunits)
        else:
            qsim.source.update_in('Quickgen source {0}'.format(j),
                                  thisobjtype.lower(),
                                  wavelengths * u.Angstrom,
                                  spectra[j, :] * fluxunits)
        qsim.source.update_out()

        qsim.simulate()
        qsim.generate_random_noise(random_state)

        for i, output in enumerate(qsim.camera_output):
            assert output['observed_flux'].unit == 1e17 * fluxunits
            # Extract the simulation results needed to create our uncalibrated
            # frame output file.
            num_pixels = len(output)
            nobj[j, i, :num_pixels] = output['num_source_electrons'][:, 0]
            nsky[j, i, :num_pixels] = output['num_sky_electrons'][:, 0]
            nivar[j, i, :num_pixels] = 1.0 / output['variance_electrons'][:, 0]

            # Get results for our flux-calibrated output file.
            cframe_observedflux[
                j, i, :num_pixels] = 1e17 * output['observed_flux'][:, 0]
            cframe_ivar[
                j,
                i, :num_pixels] = 1e-34 * output['flux_inverse_variance'][:, 0]

            # Fill brick arrays from the results.
            camera = output.meta['name']
            trueflux[camera][j][:] = 1e17 * output['observed_flux'][:, 0]
            noisyflux[camera][j][:] = 1e17 * (
                output['observed_flux'][:, 0] +
                output['flux_calibration'][:, 0] *
                output['random_noise_electrons'][:, 0])
            obsivar[camera][j][:] = 1e-34 * output['flux_inverse_variance'][:,
                                                                            0]

            # Use the same noise realization in the cframe and frame, without any
            # additional noise from sky subtraction for now.
            frame_rand_noise[
                j, i, :num_pixels] = output['random_noise_electrons'][:, 0]
            cframe_rand_noise[j, i, :num_pixels] = 1e17 * (
                output['flux_calibration'][:, 0] *
                output['random_noise_electrons'][:, 0])

            # The sky output file represents a model fit to ~40 sky fibers.
            # We reduce the variance by a factor of 25 to account for this and
            # give the sky an independent (Gaussian) noise realization.
            sky_ivar[
                j,
                i, :num_pixels] = 25.0 / (output['variance_electrons'][:, 0] -
                                          output['num_source_electrons'][:, 0])
            sky_rand_noise[j, i, :num_pixels] = random_state.normal(
                scale=1.0 / np.sqrt(sky_ivar[j, i, :num_pixels]),
                size=num_pixels)

    armName = {"b": 0, "r": 1, "z": 2}
    for channel in 'brz':

        # Before writing, convert from counts/bin to counts/A (as in Pixsim output)
        # Quicksim defaults:
        #   FLUX - input spectrum resampled to this binning; no noise added [1e-17 erg/s/cm2/Ang]
        #   COUNTS_OBJ - object counts in a 0.5 Ang bin
        #   COUNTS_SKY - sky counts in a 0.5 Ang bin

        num_pixels = len(waves[channel])
        dwave = np.gradient(waves[channel])
        nobj[:, armName[channel], :num_pixels] /= dwave
        frame_rand_noise[:, armName[channel], :num_pixels] /= dwave
        nivar[:, armName[channel], :num_pixels] *= dwave**2
        nsky[:, armName[channel], :num_pixels] /= dwave
        sky_rand_noise[:, armName[channel], :num_pixels] /= dwave
        sky_ivar[:, armName[channel], :num_pixels] *= dwave**2  # ivar scales as dw**2, consistent with nivar above
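        #- Hedged aside: for a bin of width dw Angstroms, counts-per-bin and
        #- counts-per-Angstrom relate as n_A = n_bin / dw, so the inverse
        #- variance transforms as ivar_A = ivar_bin * dw**2; e.g. 100 counts
        #- in a 0.5 A bin become 200 counts/A with ivar scaled by 0.25.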

        # Now write the outputs in the DESI standard file system.
        # No output file can have more than 500 spectra.

        # Looping over spectrograph
        for ii in range((args.nspec + args.nstart - 1) // args.n_fibers + 1):

            start = max(args.n_fibers * ii,
                        args.nstart)  # first spectrum for a given spectrograph
            end = min(args.n_fibers * (ii + 1),
                      nmax)  # last spectrum for the spectrograph

            if (args.spectrograph <= ii):
                camera = "{}{}".format(channel, ii)
                log.info(
                    "Writing files for channel:{}, spectrograph:{}, spectra:{} to {}"
                    .format(channel, ii, start, end))
                num_pixels = len(waves[channel])

                # Write frame file
                framefileName = desispec.io.findfile("frame", NIGHT, EXPID,
                                                     camera)

                frame_flux = (nobj[start:end, armName[channel], :num_pixels] +
                              nsky[start:end, armName[channel], :num_pixels] +
                              frame_rand_noise[start:end,
                                               armName[channel], :num_pixels])
                frame_ivar = nivar[start:end, armName[channel], :num_pixels]

                # sh1 is needed to slice the resolution matrix, which has
                # shape (nspec, ndiag, nwave);
                # for example, with nstart=400 and nspec=150 there are two
                # spectrographs: 400-499 => spectrograph 0, 500-549 => 1
                sh1 = frame_flux.shape[0]
                if (args.nstart == start):
                    resol = resolution[channel][:sh1, :, :]
                else:
                    resol = resolution[channel][-sh1:, :, :]

                # must create desispec.Frame object
                frame=Frame(waves[channel], frame_flux, frame_ivar,\
                    resolution_data=resol, spectrograph=ii, \
                    fibermap=fibermap[start:end], \
                    meta=dict(CAMERA=camera, FLAVOR=simspec.flavor) )
                desispec.io.write_frame(framefileName, frame)

                framefilePath = desispec.io.findfile("frame", NIGHT, EXPID,
                                                     camera)
                log.info("Wrote file {}".format(framefilePath))

                if args.frameonly or simspec.flavor == 'arc':
                    continue

                # Write cframe file
                cframeFileName = desispec.io.findfile("cframe", NIGHT, EXPID,
                                                      camera)
                cframeFlux = cframe_observedflux[
                    start:end,
                    armName[channel], :num_pixels] + cframe_rand_noise[
                        start:end, armName[channel], :num_pixels]
                cframeIvar = cframe_ivar[start:end,
                                         armName[channel], :num_pixels]

                # must create desispec.Frame object
                cframe = Frame(waves[channel], cframeFlux, cframeIvar, \
                    resolution_data=resol, spectrograph=ii,
                    fibermap=fibermap[start:end],
                    meta=dict(CAMERA=camera, FLAVOR=simspec.flavor, NIGHT=NIGHT, EXPID=EXPID, TILEID=TILEID) )
                desispec.io.frame.write_frame(cframeFileName, cframe)

                cframefilePath = desispec.io.findfile("cframe", NIGHT, EXPID,
                                                      camera)
                log.info("Wrote file {}".format(cframefilePath))

                # Write sky file
                skyfileName = desispec.io.findfile("sky", NIGHT, EXPID, camera)
                skyflux=nsky[start:end,armName[channel],:num_pixels] + \
                sky_rand_noise[start:end,armName[channel],:num_pixels]
                skyivar = sky_ivar[start:end, armName[channel], :num_pixels]
                skymask = np.zeros(skyflux.shape, dtype=np.uint32)

                # must create desispec.Sky object
                skymodel = SkyModel(waves[channel],
                                    skyflux,
                                    skyivar,
                                    skymask,
                                    header=dict(CAMERA=camera))
                desispec.io.sky.write_sky(skyfileName, skymodel)

                skyfilePath = desispec.io.findfile("sky", NIGHT, EXPID, camera)
                log.info("Wrote file {}".format(skyfilePath))

                # Write calib file
                calibVectorFile = desispec.io.findfile("calib", NIGHT, EXPID,
                                                       camera)
                flux = cframe_observedflux[start:end,
                                           armName[channel], :num_pixels]
                phot = nobj[start:end, armName[channel], :num_pixels]
                calibration = np.zeros_like(phot)
                jj = (flux > 0)
                calibration[jj] = phot[jj] / flux[jj]

                #- TODO: what should calibivar be?
                #- For now, model it as the noise of combining ~10 spectra
                calibivar = 10 / cframe_ivar[start:end,
                                             armName[channel], :num_pixels]
                #mask=(1/calibivar>0).astype(int)??
                mask = np.zeros(calibration.shape, dtype=np.uint32)

                # write flux calibration
                fluxcalib = FluxCalib(waves[channel], calibration, calibivar,
                                      mask)
                write_flux_calibration(calibVectorFile, fluxcalib)

                calibfilePath = desispec.io.findfile("calib", NIGHT, EXPID,
                                                     camera)
                log.info("Wrote file {}".format(calibfilePath))
Beispiel #50
0
def write_raw(filename, rawdata, header, camera=None, primary_header=None):
    '''
    Write raw pixel data to a DESI raw data file

    Args:
        filename : file name to write data; if this exists, append a new HDU
        rawdata : 2D ndarray of raw pixel data including overscans
        header : dict-like object or fits.Header with keywords
            CCDSECx, BIASSECx, DATASECx where x=1,2,3, or 4

    Options:
        camera : b0, r1 .. z9 - override value in header
        primary_header : header to write in HDU0 if filename doesn't yet exist

    The primary utility of this function over raw fits calls is to ensure
    that all necessary keywords are present before writing the file:
    CCDSECx, BIASSECx, DATASECx (x=1,2,3, or 4) are required, while missing
    DATE-OBS, GAINx, or RDNOISEx generates a non-fatal warning.
    '''
    log = get_logger()

    header = desispec.io.util.fitsheader(header)
    primary_header = desispec.io.util.fitsheader(primary_header)

    if rawdata.dtype not in (np.int16, np.int32, np.int64):
        message = 'dtype {} not supported for raw data'.format(rawdata.dtype)
        log.fatal(message)
        raise ValueError(message)

    fail_message = ''
    for required_key in ['DOSVER', 'FEEVER', 'DETECTOR']:
        if required_key not in primary_header:
            if required_key in header:
                primary_header[required_key] = header[required_key]
            else:
                fail_message = fail_message + \
                    'Keyword {} must be in header or primary_header\n'.format(required_key)
    if fail_message != '':
        raise ValueError(fail_message)

    #- Check required keywords before writing anything
    missing_keywords = list()
    if camera is None and 'CAMERA' not in header:
        log.error("Must provide camera keyword or header['CAMERA']")
        missing_keywords.append('CAMERA')

    for amp in ['1', '2', '3', '4']:
        for prefix in ['CCDSEC', 'BIASSEC', 'DATASEC']:
            keyword = prefix + amp
            if keyword not in header:
                log.error('Missing keyword ' + keyword)
                missing_keywords.append(keyword)

    #- Missing DATE-OBS is warning but not error
    if 'DATE-OBS' not in primary_header:
        if 'DATE-OBS' in header:
            primary_header['DATE-OBS'] = header['DATE-OBS']
        else:
            log.warning('missing keyword DATE-OBS')

    #- Missing GAINx is warning but not error
    for amp in ['1', '2', '3', '4']:
        keyword = 'GAIN' + amp
        if keyword not in header:
            log.warning('Gain keyword {} missing; using 1.0'.format(keyword))
            header[keyword] = 1.0

    #- Missing RDNOISEx is warning but not error
    for amp in ['1', '2', '3', '4']:
        keyword = 'RDNOISE' + amp
        if keyword not in header:
            log.warning('Readnoise keyword {} missing'.format(keyword))

    #- Stop if any keywords are missing
    if len(missing_keywords) > 0:
        raise KeyError('missing required keywords {}'.format(missing_keywords))

    #- Set EXTNAME=camera
    if camera is not None:
        header['CAMERA'] = camera.lower()
        extname = camera.upper()
    else:
        if header['CAMERA'] != header['CAMERA'].lower():
            log.warning('Converting CAMERA {} to lowercase'.format(
                header['CAMERA']))
            header['CAMERA'] = header['CAMERA'].lower()
        extname = header['CAMERA'].upper()

    header['INHERIT'] = True

    #- fits.CompImageHDU doesn't know how to fill in default keywords, so
    #- temporarily generate an uncompressed HDU to get those keywords
    header = fits.ImageHDU(rawdata, header=header, name=extname).header

    #- Bizarrely, compression of 64-bit integers isn't supported.
    #- downcast to 32-bit if that won't lose precision.
    #- Real raw data should be 32-bit or 16-bit anyway
    if rawdata.dtype == np.int64:
        if np.max(np.abs(rawdata)) < 2**31:
            rawdata = rawdata.astype(np.int32)

    if rawdata.dtype in (np.int16, np.int32):
        dataHDU = fits.CompImageHDU(rawdata, header=header, name=extname)
    elif rawdata.dtype == np.int64:
        log.warning(
            'Image compression not supported for 64-bit; writing uncompressed')
        dataHDU = fits.ImageHDU(rawdata, header=header, name=extname)
    else:
        log.error("How did we get this far with rawdata dtype {}?".format(
            rawdata.dtype))
        dataHDU = fits.ImageHDU(rawdata, header=header, name=extname)

    #- Actually write or update the file
    if os.path.exists(filename):
        hdus = fits.open(filename, mode='append', memmap=False)
        if extname in hdus:
            hdus.close()
            raise ValueError('Camera {} already in {}'.format(
                camera, filename))
        else:
            hdus.append(dataHDU)
            hdus.flush()
            hdus.close()
    else:
        hdus = fits.HDUList()
        add_dependencies(primary_header)
        hdus.append(fits.PrimaryHDU(None, header=primary_header))
        hdus.append(dataHDU)
        hdus.writeto(filename)
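#- Hedged usage sketch: the minimal keyword set this function checks for,
#- with placeholder section values; the amplifier sections below are made up.

# import numpy as np
# rawdata = np.zeros((200, 200), dtype=np.int32)
# header = dict(CAMERA='b0', DOSVER='sim', FEEVER='sim', DETECTOR='sim')
# for amp in '1234':
#     for prefix in ('CCDSEC', 'BIASSEC', 'DATASEC'):
#         header[prefix + amp] = '[1:100,1:100]'
# write_raw('desi-00000000.fits.fz', rawdata, header, camera='b0')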
Beispiel #51
0
def get_exp2healpix_map(nights=None, specprod_dir=None, nside=64, comm=None):
    '''
    Returns table with columns NIGHT EXPID SPECTRO HEALPIX NTARGETS

    Options:
        nights: list of YEARMMDD to scan for exposures
        specprod_dir: override $DESI_SPECTRO_REDUX/$SPECPROD
        nside: healpix nside, must be power of 2
        comm: MPI communicator

    Note: This could be replaced by a DB query when the production DB exists.
    '''
    log = get_logger()
    if comm is None:
        rank, size = 0, 1
    else:
        rank, size = comm.rank, comm.size

    if specprod_dir is None:
        specprod_dir = io.specprod_root()

    if nights is None and rank == 0:
        nights = io.get_nights(specprod_dir=specprod_dir)

    if comm:
        nights = comm.bcast(nights, root=0)

    #-----
    #- Distribute nights over ranks, scanning their exposures to build
    #- map of exposures -> healpix

    #- Rows to add to the output table
    rows = list()

    #- for tracking exposures that we've already mapped in a different band
    night_expid_spectro = set()

    for night in nights[rank::size]:
        night = str(night)
        nightdir = os.path.join(specprod_dir, 'exposures', night)
        for expid in io.get_exposures(night,
                                      specprod_dir=specprod_dir,
                                      raw=False):
            tmpframe = io.findfile('cframe',
                                   night,
                                   expid,
                                   'r0',
                                   specprod_dir=specprod_dir)
            expdir = os.path.split(tmpframe)[0]
            cframefiles = sorted(glob.glob(expdir + '/cframe*.fits'))
            for filename in cframefiles:
                #- parse 'path/night/expid/cframe-r0-12345678.fits'
                camera = os.path.basename(filename).split('-')[1]
                channel, spectro = camera[0], int(camera[1])

                #- skip if we already have this expid/spectrograph
                if (night, expid, spectro) in night_expid_spectro:
                    continue
                else:
                    night_expid_spectro.add((night, expid, spectro))

                log.debug('Rank {} mapping {} {}'.format(
                    rank, night, os.path.basename(filename)))
                sys.stdout.flush()

                #- Determine healpix, allowing for NaN
                columns = ['TARGET_RA', 'TARGET_DEC']
                fibermap = fitsio.read(filename, 'FIBERMAP', columns=columns)
                ra, dec = fibermap['TARGET_RA'], fibermap['TARGET_DEC']
                ok = ~np.isnan(ra) & ~np.isnan(dec)
                ra, dec = ra[ok], dec[ok]
                allpix = desimodel.footprint.radec2pix(nside, ra, dec)

                #- Add rows for final output
                for pix, ntargets in sorted(Counter(allpix).items()):
                    rows.append((night, expid, spectro, pix, ntargets))

    #- Collect rows from individual ranks back to rank 0
    if comm:
        rank_rows = comm.gather(rows, root=0)
        if rank == 0:
            rows = list()
            for r in rank_rows:
                rows.extend(r)
        else:
            rows = None

        rows = comm.bcast(rows, root=0)

    #- Create the final output table
    exp2healpix = np.array(rows,
                           dtype=[('NIGHT', 'i4'), ('EXPID', 'i8'),
                                  ('SPECTRO', 'i4'), ('HEALPIX', 'i8'),
                                  ('NTARGETS', 'i8')])

    return exp2healpix
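#- Hedged usage sketch: build the exposure-to-healpix map and tally targets
#- per healpix; the night values below are illustrative.

# import numpy as np
# exp2healpix = get_exp2healpix_map(nights=[20200314, 20200315], nside=64)
# for pix in np.unique(exp2healpix['HEALPIX']):
#     sel = exp2healpix['HEALPIX'] == pix
#     print(pix, exp2healpix['NTARGETS'][sel].sum())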
def parse(options=None):
    """
    Can change night or number of spectra to be simulated and delete all output of test
    Won't overwrite exisiting data unless overwrite argument provided

    QuickLook data read from $QL_SPEC_DATA
    QuickLook output written to $QL_SPEC_REDUX

    Environment Variable check included here
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--night', type=str, default='20160728',
                        help='night to be simulated')
    parser.add_argument('--nspec', type=int, default=5,
                        help='number of spectra to be simulated, starting from the first')
    parser.add_argument('--overwrite', action='store_true',
                        help='overwrite existing files')
    parser.add_argument('--delete', action='store_true',
                        help='delete all files generated by this test')

    if options is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(options)

    log = logging.get_logger()
    log.setLevel(logging.DEBUG)
    missing_env = False

    if 'DESI_BASIS_TEMPLATES' not in os.environ:
        log.warning('missing $DESI_BASIS_TEMPLATES needed for simulating spectra')
        missing_env = True
    elif not os.path.isdir(os.getenv('DESI_BASIS_TEMPLATES')):
        log.warning('missing $DESI_BASIS_TEMPLATES directory')
        log.warning('e.g. see NERSC:/project/projectdirs/desi/spectro/templates/basis_templates/v1.0')
        missing_env = True

    for name in ('DESI_SPECTRO_SIM', 'PIXPROD', 'DESIMODEL',
                 'QL_SPEC_REDUX', 'QL_SPEC_DATA'):
        if name not in os.environ:
            log.warning("missing ${}".format(name))
            missing_env = True

    if missing_env:
        log.warning("Why are these needed?")
        log.warning("    Simulations written to $DESI_SPECTRO_SIM/$PIXPROD")
        log.warning("    Raw data read from $QL_SPEC_DATA")
        log.warning("    Spectro/QuickLook pipeline output written to $QL_SPEC_REDUX")
        log.warning("    PSF files are found in $DESIMODEL")
        log.warning("    Templates are read from $DESI_BASIS_TEMPLATES")

    #- Wait until the end to exit so that we report everything that is
    #- missing before actually failing
    if missing_env:
        log.critical("missing env vars; exiting without running simulations or quicklook pipeline")
        sys.exit(1)

    sim_dir = os.path.join(os.environ['DESI_SPECTRO_SIM'],os.environ['PIXPROD'],args.night)
    data_dir = os.path.join(os.environ['QL_SPEC_DATA'],args.night)
    output_dir = os.environ['QL_SPEC_REDUX']

    if args.overwrite:
        def remove_dir(path):
            # remove every file in `path`, then the (now empty) directory
            if os.path.exists(path):
                for name in os.listdir(path):
                    os.remove(os.path.join(path, name))
                os.rmdir(path)

        remove_dir(sim_dir)
        remove_dir(data_dir)
        if os.path.exists(output_dir):
            exp_dir = os.path.join(output_dir, 'exposures', args.night)
            calib_dir = os.path.join(output_dir, 'calib2d', args.night)
            if os.path.exists(exp_dir):
                remove_dir(os.path.join(exp_dir, '00000002'))
                remove_dir(exp_dir)
            remove_dir(calib_dir)

    else:
        if os.path.exists(sim_dir) or os.path.exists(data_dir) or os.path.exists(output_dir):
            raise RuntimeError('Files already exist for this night! Use --overwrite or change the night.')

    return args
Beispiel #53
0
#!/usr/bin/env python

import argparse
import sys

import numpy as np
from astropy.table import Table

from timedomain.filters import *
from timedomain.iterators import *
from timedomain.sp_utils import *
from timedomain.fs_utils import *
import timedomain.config as config
from desiutil.log import get_logger, DEBUG

log = get_logger(DEBUG)


def main(args):
    """ Main entry point of the app """
    print("Start ", args)
    logic = getattr(sys.modules[__name__], args.logic)
    iterator_ = getattr(sys.modules[__name__], args.iterator)

    prunelogic = args.logic[0:args.logic.find("Logic")]

    ### Get the tile and array from the arguments

    if args.obsdates_tilenumbers is not None:
        obsdates_tilenumbers_str = args.obsdates_tilenumbers
        obsdates_tilenumbers = np.chararray((len(obsdates_tilenumbers_str), 2),
                                            itemsize=10,
                                            unicode=True)
        for i in range(len(obsdates_tilenumbers_str)):
            obsdates_tilenumbers[i, :] = obsdates_tilenumbers_str[i].split('|')
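The loop above (truncated here) unpacks --obsdates_tilenumbers arguments of the form OBSDATE|TILENUMBER into an (N, 2) string array. A self-contained sketch of that parsing, with made-up input values:

import numpy as np

obsdates_tilenumbers_str = ['20210610|80605', '20210611|80606']  # hypothetical inputs
obsdates_tilenumbers = np.chararray((len(obsdates_tilenumbers_str), 2),
                                    itemsize=10, unicode=True)
for i, pair in enumerate(obsdates_tilenumbers_str):
    obsdates_tilenumbers[i, :] = pair.split('|')  # row = [obsdate, tilenumber]

print(obsdates_tilenumbers)  # [['20210610' '80605']
                             #  ['20210611' '80606']]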
Example #54
0
def get_nodes_per_exp(nnodes,nexposures,ncameras,user_nodes_per_comm_exp=None):
    """
    Calculate how many nodes to use per exposure

    Args:
        nnodes: number of nodes in MPI COMM_WORLD (not number of ranks)
        nexposures: number of exposures to process
        ncameras: number of cameras per exposure
        user_nodes_per_comm_exp (int, optional): user override of number of
            nodes to use; used to check requirements

    Returns number of nodes to include in sub-communicators used to process
    individual exposures

    Notes:
        * Uses the largest number of nodes per exposure that will still
          result in efficient node usage
        * requires that (nexposures*ncameras) / nnodes = int
        * the derived nodes_per_comm_exp * nexposures / nnodes = int
        * See desisim.test.test_pixsim.test_get_nodes_per_exp() for examples
        * if user_nodes_per_comm_exp is given, requires that
          GreatestCommonDivisor(nnodes, ncameras) / user_nodes_per_comm_exp = int
    """

    from math import gcd
    import desiutil.log as logging
    log = logging.get_logger()
    log.setLevel(logging.INFO)

    #check if nframes is evenly divisible by nnodes
    nframes = ncameras*nexposures
    if nframes % nnodes !=0:
        ### msg=("nframes {} must be evenly divisible by nnodes {}, try again".format(nframes, nnodes))
        ### raise ValueError(msg)
        msg=("nframes {} is not evenly divisible by nnodes {}; packing will be inefficient".format(nframes, nnodes))
        log.warning(msg)
    else:
        log.debug("nframes {} is evenly divisible by nnodes {}, check passed".format(nframes, nnodes))

    #find greatest common divisor between nnodes and ncameras
    #greatest common divisor = greatest common factor
    #we use python's built in gcd
    greatest_common_factor=gcd(nnodes,ncameras)
    #the greatest common factor must be greater than one UNLESS we are on one node
    if nnodes > 1:
        if greatest_common_factor == 1:
            msg=("greatest common factor {} between nnodes {} and ncameras {} must be larger than one, try again".format(greatest_common_factor, nnodes, ncameras))
            raise ValueError(msg)
        else:
            log.debug("greatest common factor {} between nnodes {} and ncameras {} is greater than one, check passed".format(greatest_common_factor, nnodes, ncameras))

    #check that the user hasn't specified an invalid value of user_nodes_per_comm_exp
    if user_nodes_per_comm_exp is not None:
        if greatest_common_factor % user_nodes_per_comm_exp !=0:
            msg=("user-specified value of user_nodes_per_comm_exp {} does not evenly divide the greatest common factor {} of nnodes and ncameras, try again".format(user_nodes_per_comm_exp, greatest_common_factor))
            raise ValueError(msg)
        else:
            log.debug("user-specified value of user_nodes_per_comm_exp {} is valid, check passed".format(user_nodes_per_comm_exp))
            nodes_per_comm_exp=user_nodes_per_comm_exp
    #if the user didn't specify anything, use the greatest common factor
    if user_nodes_per_comm_exp is None:
        nodes_per_comm_exp=greatest_common_factor

    #finally check to make sure exposures*gcf/nnodes is an integer to avoid inefficient node use
    if (nexposures*nodes_per_comm_exp) % nnodes != 0:
        ### msg=("nexposures {} * nodes_per_comm_exp {} does not divide evenly into nnodes {}, try again".format(nexposures, nodes_per_comm_exp, nnodes))
        ### raise ValueError(msg)
        msg=("nexposures {} * nodes_per_comm_exp {} does not divide evenly into nnodes {}; packing will be inefficient".format(nexposures, nodes_per_comm_exp, nnodes))
        log.warning(msg)
    else:
        log.debug("nexposures {} * nodes_per_comm_exp {} divides evenly into nnodes {}, check passed".format(nexposures, nodes_per_comm_exp, nnodes))


    return nodes_per_comm_exp
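As a concrete check of the packing logic, assume a hypothetical layout of 4 nodes, 6 exposures and 30 cameras: nframes = 180 divides evenly by 4, gcd(4, 30) = 2, and 6 * 2 divides evenly by 4, so 2 nodes per exposure would be used. A minimal sketch of the same arithmetic:

from math import gcd

nnodes, nexposures, ncameras = 4, 6, 30      # hypothetical MPI layout
nframes = nexposures * ncameras              # 180, divisible by nnodes
nodes_per_comm_exp = gcd(nnodes, ncameras)   # 2, the default the function derives
assert nframes % nnodes == 0                 # efficient packing check
assert (nexposures * nodes_per_comm_exp) % nnodes == 0
print(nodes_per_comm_exp)                    # 2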
Example #55
0
def subtract_sky(frame,
                 skymodel,
                 throughput_correction=False,
                 default_throughput_correction=1.):
    """Subtract skymodel from frame, altering frame.flux, .ivar, and .mask

    Args:
        frame : desispec.Frame object
        skymodel : desispec.SkyModel object

    Options:
        throughput_correction : if True, fit for an achromatic throughput correction. This is to absorb variations of Focal Ratio Degradation with fiber flexure.
        default_throughput_correction : float, default value of correction if the fit on sky lines failed.
    """
    assert frame.nspec == skymodel.nspec
    assert frame.nwave == skymodel.nwave

    log = get_logger()
    log.info("starting")

    # check same wavelength, die if not the case
    if not np.allclose(frame.wave, skymodel.wave):
        message = "frame and sky not on same wavelength grid"
        log.error(message)
        raise ValueError(message)

    if throughput_correction:
        # need to fit for a multiplicative factor of the sky model
        # before subtraction
        # we are going to use a set of bright sky lines,
        # and fit a multiplicative factor + background around
        # each of them individually, and then combine the results
        # with outlier rejection in case a source emission line
        # coincides with one of the sky lines.

        # it's more robust to have a hardcoded set of sky lines here
        # these are all the sky lines with a flux >5% of the max flux
        # except in b where we add an extra weaker line at 5199.4A
        skyline = np.array([
            5199.4, 5578.4, 5656.4, 5891.4, 5897.4, 6302.4, 6308.4, 6365.4,
            6500.4, 6546.4, 6555.4, 6618.4, 6663.4, 6679.4, 6690.4, 6765.4,
            6831.4, 6836.4, 6865.4, 6925.4, 6951.4, 6980.4, 7242.4, 7247.4,
            7278.4, 7286.4, 7305.4, 7318.4, 7331.4, 7343.4, 7360.4, 7371.4,
            7394.4, 7404.4, 7440.4, 7526.4, 7714.4, 7719.4, 7752.4, 7762.4,
            7782.4, 7796.4, 7810.4, 7823.4, 7843.4, 7855.4, 7862.4, 7873.4,
            7881.4, 7892.4, 7915.4, 7923.4, 7933.4, 7951.4, 7966.4, 7982.4,
            7995.4, 8016.4, 8028.4, 8064.4, 8280.4, 8284.4, 8290.4, 8298.4,
            8301.4, 8313.4, 8346.4, 8355.4, 8367.4, 8384.4, 8401.4, 8417.4,
            8432.4, 8454.4, 8467.4, 8495.4, 8507.4, 8627.4, 8630.4, 8634.4,
            8638.4, 8652.4, 8657.4, 8662.4, 8667.4, 8672.4, 8677.4, 8683.4,
            8763.4, 8770.4, 8780.4, 8793.4, 8829.4, 8835.4, 8838.4, 8852.4,
            8870.4, 8888.4, 8905.4, 8922.4, 8945.4, 8960.4, 8990.4, 9003.4,
            9040.4, 9052.4, 9105.4, 9227.4, 9309.4, 9315.4, 9320.4, 9326.4,
            9340.4, 9378.4, 9389.4, 9404.4, 9422.4, 9442.4, 9461.4, 9479.4,
            9505.4, 9521.4, 9555.4, 9570.4, 9610.4, 9623.4, 9671.4, 9684.4,
            9693.4, 9702.4, 9714.4, 9722.4, 9740.4, 9748.4, 9793.4, 9802.4,
            9814.4, 9820.4
        ])

        sw = []
        swf = []
        sws = []
        sws2 = []
        swsf = []

        # half width of wavelength region around each sky line
        # larger values give a better statistical precision
        # but also a larger sensitivity to source features
        # best solution on one dark night exposure obtained with
        # a half width of 4A.
        hw = 4  #A
        tivar = frame.ivar.copy()  # copy so that masking here does not alter frame.ivar
        if frame.mask is not None:
            tivar *= (frame.mask == 0)
            tivar *= (skymodel.ivar > 0)

        # we precompute the quantities needed to fit each sky line + continuum
        # the sky "line profile" is the actual sky model
        # and we consider an additive constant
        for line in skyline:
            if line <= frame.wave[0] or line >= frame.wave[-1]: continue
            ii = np.where((frame.wave >= line - hw)
                          & (frame.wave <= line + hw))[0]
            if ii.size < 2: continue
            sw.append(np.sum(tivar[:, ii], axis=1))
            swf.append(np.sum(tivar[:, ii] * frame.flux[:, ii], axis=1))
            swsf.append(
                np.sum(tivar[:, ii] * frame.flux[:, ii] * skymodel.flux[:, ii],
                       axis=1))
            sws.append(np.sum(tivar[:, ii] * skymodel.flux[:, ii], axis=1))
            sws2.append(np.sum(tivar[:, ii] * skymodel.flux[:, ii]**2, axis=1))

        nlines = len(sw)

        for fiber in range(frame.flux.shape[0]):

            # we solve the 2x2 linear system for each fiber and sky line
            # and save the results for each fiber

            coef = []  # list of scale values
            var = []  # list of variance on scale values
            for line in range(nlines):
                if sw[line][fiber] <= 0: continue
                A = np.array([[sw[line][fiber], sws[line][fiber]],
                              [sws[line][fiber], sws2[line][fiber]]])
                B = np.array([swf[line][fiber], swsf[line][fiber]])
                try:
                    Ai = np.linalg.inv(A)
                    X = Ai.dot(B)
                    coef.append(
                        X[1]
                    )  # the scale coef (marginalized over cst background)
                    var.append(Ai[1, 1])
                except np.linalg.LinAlgError:
                    pass  # singular 2x2 system for this sky line, skip it

            if len(coef) == 0:
                log.warning("cannot corr. throughput. for fiber %d" % fiber)
                continue

            coef = np.array(coef)
            var = np.array(var)
            ivar = (var > 0) / (var + (var == 0) + 0.005**2)
            ivar_for_outliers = (var > 0) / (var + (var == 0) + 0.02**2)

            # loop for outlier rejection
            failed = False
            for loop in range(50):
                a = np.sum(ivar)
                if a <= 0:
                    log.warning(
                        "cannot compute throughput correction, ivar=0 everywhere on sky lines for fiber %d"
                        % fiber)
                    failed = True
                    break

                mcoef = np.sum(ivar * coef) / a
                mcoeferr = 1 / np.sqrt(a)

                nsig = 3.
                chi2 = ivar_for_outliers * (coef - mcoef)**2
                worst = np.argmax(chi2)
                if chi2[worst] > nsig**2 * np.median(
                        chi2[chi2 > 0]):  # with rough scaling of errors
                    #log.debug("discard a bad measurement for fiber %d"%(fiber))
                    ivar[worst] = 0
                    ivar_for_outliers[worst] = 0
                else:
                    break

            if failed:
                continue

            log.info(
                "fiber #%03d throughput corr = %5.4f +- %5.4f (mean fiber flux=%f)"
                % (fiber, mcoef, mcoeferr, np.median(frame.flux[fiber])))
            '''
            if np.abs(mcoef)>0.01 :
                
                print(fiber,"mean coef=",mcoef,"all coef=",coef)
                print(fiber,"all err=",np.sqrt(var))
                print(fiber,"mean coef=",mcoef,"selected coef=",coef[ivar>0])
                print(fiber,"select err=",np.sqrt(var[ivar>0]))
                import matplotlib.pyplot as plt
                x=np.arange(coef.size)
                plt.errorbar(x,coef,np.sqrt(var),fmt="o")
                plt.errorbar(x[ivar>0],coef[ivar>0],np.sqrt(var[ivar>0]),fmt="o")
                plt.axhline(0.)
                plt.axhline(mcoef)
                plt.ylim(-0.11,0.11)
                plt.grid()
                plt.show()
            '''

            if mcoeferr > 0.01:
                log.warning(
                    "throughput corr error = %5.4f is too large for fiber #%03d, do not apply correction"
                    % (mcoeferr, fiber))
                throughput_correction_value = default_throughput_correction
            else:
                throughput_correction_value = mcoef

            # apply this correction to the sky model even if we have not fit it (default can be 1 or 0)
            skymodel.flux[fiber] *= throughput_correction_value

    frame.flux -= skymodel.flux
    frame.ivar = util.combine_ivar(frame.ivar, skymodel.ivar)
    frame.mask |= skymodel.mask

    log.info("done")
Example #56
0
def write_image(outfile, image, meta=None):
    """Writes image object to outfile

    Args:
        outfile : output file string
        image : desispec.image.Image object
            (or any object with 2D array attributes image, ivar, mask)

    Optional:
        meta : dict-like object with metadata key/values (e.g. FITS header)
    """

    log = get_logger()
    if meta is not None:
        hdr = fitsheader(meta)
    else:
        hdr = fitsheader(image.meta)

    add_dependencies(hdr)

    #- Work around fitsio>1.0 writing blank keywords, e.g. on 20191212
    for key in hdr.keys():
        if isinstance(hdr[key], fits.card.Undefined):
            log.warning('Setting blank keyword {} to None'.format(key))
            hdr[key] = None

    outdir = os.path.dirname(os.path.abspath(outfile))
    if not os.path.isdir(outdir):
        os.makedirs(outdir)

    hx = fits.HDUList()
    hdu = fits.ImageHDU(image.pix.astype(np.float32), name='IMAGE', header=hdr)
    if 'CAMERA' not in hdu.header:
        hdu.header.append(
            ('CAMERA', image.camera.lower(), 'Spectrograph Camera'))

    if 'RDNOISE' not in hdu.header and np.isscalar(image.readnoise):
        hdu.header.append(
            ('RDNOISE', image.readnoise, 'Read noise [RMS electrons/pixel]'))

    hx.append(hdu)
    hx.append(fits.ImageHDU(image.ivar.astype(np.float32), name='IVAR'))
    hx.append(fits.CompImageHDU(image.mask.astype(np.int16), name='MASK'))
    if not np.isscalar(image.readnoise):
        hx.append(
            fits.ImageHDU(image.readnoise.astype(np.float32),
                          name='READNOISE'))

    if hasattr(image, 'fibermap'):
        if isinstance(image.fibermap, Table):
            fmhdu = fits.convenience.table_to_hdu(image.fibermap)
            fmhdu.name = 'FIBERMAP'
        else:
            fmhdu = fits.BinTableHDU(image.fibermap, name='FIBERMAP')

        hx.append(fmhdu)

    hx.writeto(outfile + '.tmp', overwrite=True, checksum=True)
    os.rename(outfile + '.tmp', outfile)

    return outfile
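A file written this way can be read back HDU by HDU. A minimal sketch, assuming a file produced by this function exists at the hypothetical path below (desispec also provides its own image readers):

from astropy.io import fits

with fits.open('preproc-b0-00000002.fits') as hx:  # hypothetical filename
    pix = hx['IMAGE'].data                # float32 pixel values
    ivar = hx['IVAR'].data                # float32 inverse variance
    mask = hx['MASK'].data                # int16, stored compressed
    camera = hx['IMAGE'].header['CAMERA']
    if 'READNOISE' in hx:                 # only present for per-pixel read noise
        readnoise = hx['READNOISE'].data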
Example #57
0
def compute_polynomial_times_sky(frame,
                                 nsig_clipping=4.,
                                 max_iterations=30,
                                 model_ivar=False,
                                 add_variance=True,
                                 angular_variation_deg=1,
                                 chromatic_variation_deg=1):
    """Compute a sky model.
    
    Sky[fiber,i] = R[fiber,i,j] Polynomial(x[fiber],y[fiber],wavelength[j]) Flux[j]
    
    Input flux is expected to be flatfielded!
    We don't check this in this routine.

    Args:
        frame : Frame object, which includes attributes
          - wave : 1D wavelength grid in Angstroms
          - flux : 2D flux[nspec, nwave] density
          - ivar : 2D inverse variance of flux
          - mask : 2D mask of flux (0=good)
          - resolution_data : 3D[nspec, ndiag, nwave]  (only sky fibers)
        nsig_clipping : [optional] sigma clipping value for outlier rejection

    Optional:
        max_iterations : int , number of iterations
        model_ivar : replace ivar by a model to avoid bias due to correlated flux and ivar. this has a negligible effect on sims.
        add_variance : evaluate calibration error and add this to the sky model variance
        
    returns SkyModel object with attributes wave, flux, ivar, mask
    """

    log = get_logger()
    log.info("starting")

    # Grab sky fibers on this frame
    skyfibers = np.where(frame.fibermap['OBJTYPE'] == 'SKY')[0]
    assert np.max(skyfibers) < 500  #- indices, not fiber numbers

    nwave = frame.nwave
    nfibers = len(skyfibers)

    current_ivar = frame.ivar[skyfibers].copy() * (frame.mask[skyfibers] == 0)
    flux = frame.flux[skyfibers]
    Rsky = frame.R[skyfibers]

    input_ivar = None
    if model_ivar:
        log.info(
            "use a model of the inverse variance to remove bias due to correlated ivar and flux"
        )
        input_ivar = current_ivar.copy()
        median_ivar_vs_wave = np.median(current_ivar, axis=0)
        median_ivar_vs_fiber = np.median(current_ivar, axis=1)
        median_median_ivar = np.median(median_ivar_vs_fiber)
        for f in range(current_ivar.shape[0]):
            threshold = 0.01
            current_ivar[f] = median_ivar_vs_fiber[
                f] / median_median_ivar * median_ivar_vs_wave
            # keep input ivar for very low weights
            ii = (input_ivar[f] <= (threshold * median_ivar_vs_wave))
            #log.info("fiber {} keep {}/{} original ivars".format(f,np.sum(ii),current_ivar.shape[1]))
            current_ivar[f][ii] = input_ivar[f][ii]

    # need focal plane coordinates
    x = frame.fibermap["FIBERASSIGN_X"]
    y = frame.fibermap["FIBERASSIGN_Y"]

    # normalize for numerical stability
    xm = np.mean(x)
    ym = np.mean(y)
    xs = np.std(x)
    ys = np.std(y)
    if xs == 0: xs = 1
    if ys == 0: ys = 1
    x = (x - xm) / xs
    y = (y - ym) / ys
    w = (frame.wave - frame.wave[0]) / (frame.wave[-1] -
                                        frame.wave[0]) * 2. - 1

    # precompute the monomials for the sky fibers
    log.debug("compute monomials for deg={} and {}".format(
        angular_variation_deg, chromatic_variation_deg))
    monomials = []
    for dx in range(angular_variation_deg + 1):
        for dy in range(angular_variation_deg + 1 - dx):
            xypol = (x**dx) * (y**dy)
            for dw in range(chromatic_variation_deg + 1):
                wpol = w**dw
                monomials.append(np.outer(xypol, wpol))

    ncoef = len(monomials)
    coef = np.zeros((ncoef))

    allfibers_monomials = np.array(monomials)
    log.debug("shape of allfibers_monomials = {}".format(
        allfibers_monomials.shape))

    skyfibers_monomials = allfibers_monomials[:, skyfibers, :]
    log.debug("shape of skyfibers_monomials = {}".format(
        skyfibers_monomials.shape))

    sqrtw = np.sqrt(current_ivar)
    sqrtwflux = sqrtw * flux

    chi2 = np.zeros(flux.shape)

    Pol = np.ones(flux.shape, dtype=float)
    coef[0] = 1.

    nout_tot = 0
    previous_chi2 = -10.
    for iteration in range(max_iterations):

        # the matrix A is 1/2 of the second derivative of the chi2 with respect to the parameters
        # A_ij = 1/2 d2(chi2)/di/dj
        # A_ij = sum_fiber sum_wave_w ivar[fiber,w] d(model)/di[fiber,w] * d(model)/dj[fiber,w]

        # the vector B is 1/2 of the first derivative of the chi2 with respect to the parameters
        # B_i  = 1/2 d(chi2)/di
        # B_i  = sum_fiber sum_wave_w ivar[fiber,w] d(model)/di[fiber,w] * (flux[fiber,w]-model[fiber,w])

        # the model is model[fiber]=R[fiber]*Pol(x,y,wave)*sky
        # the parameters are the unconvolved sky flux at the wavelength i
        # and the polynomial coefficients

        A = np.zeros((nwave, nwave), dtype=float)
        B = np.zeros((nwave), dtype=float)
        D = scipy.sparse.lil_matrix((nwave, nwave))
        D2 = scipy.sparse.lil_matrix((nwave, nwave))

        Pol /= coef[0]  # force constant term to 1.

        # solving for the deconvolved mean sky spectrum
        # loop on fiber to handle resolution
        for fiber in range(nfibers):
            if fiber % 10 == 0:
                log.info("iter %d sky fiber (1st fit) %d/%d" %
                         (iteration, fiber, nfibers))
            D.setdiag(sqrtw[fiber])
            D2.setdiag(Pol[fiber])
            sqrtwRP = D.dot(Rsky[fiber]).dot(
                D2)  # each row r of R is multiplied by sqrtw[r]
            A += (sqrtwRP.T * sqrtwRP).todense()
            B += sqrtwRP.T * sqrtwflux[fiber]

        log.info("iter %d solving" % iteration)
        w = A.diagonal() > 0
        A_pos_def = A[w, :]
        A_pos_def = A_pos_def[:, w]
        parameters = B * 0
        try:
            parameters[w] = cholesky_solve(A_pos_def, B[w])
        except Exception:
            log.info("cholesky failed, trying svd in iteration {}".format(
                iteration))
            parameters[w] = np.linalg.lstsq(A_pos_def, B[w], rcond=None)[0]
        # parameters = the deconvolved mean sky spectrum

        # now evaluate the polynomial coefficients
        Ap = np.zeros((ncoef, ncoef), dtype=float)
        Bp = np.zeros((ncoef), dtype=float)
        D2.setdiag(parameters)
        for fiber in range(nfibers):
            if fiber % 10 == 0:
                log.info("iter %d sky fiber  (2nd fit) %d/%d" %
                         (iteration, fiber, nfibers))
            D.setdiag(sqrtw[fiber])
            sqrtwRSM = D.dot(Rsky[fiber]).dot(D2).dot(
                skyfibers_monomials[:, fiber, :].T)
            Ap += sqrtwRSM.T.dot(sqrtwRSM)
            Bp += sqrtwRSM.T.dot(sqrtwflux[fiber])

        # Add huge prior on zeroth angular order terms to converge faster
        # (because those terms are degenerate with the mean deconvolved spectrum)
        weight = 1e24
        Ap[0, 0] += weight
        Bp[0] += weight  # force 0th term to 1
        for i in range(1, chromatic_variation_deg + 1):
            Ap[i, i] += weight  # force other wavelength terms to 0

        coef = cholesky_solve(Ap, Bp)
        log.info("pol coef = {}".format(coef))

        # recompute the polynomial values
        Pol = skyfibers_monomials.T.dot(coef).T

        # chi2 and outlier rejection
        log.info("iter %d compute chi2" % iteration)
        for fiber in range(nfibers):
            chi2[fiber] = current_ivar[fiber] * (
                flux[fiber] - Rsky[fiber].dot(Pol[fiber] * parameters))**2

        log.info("rejecting")

        nout_iter = 0
        if iteration < 1:
            # only remove worst outlier per wave
            # apply rejection iteratively, only one entry per wave among fibers
            # find waves with outlier (fastest way)
            nout_per_wave = np.sum(chi2 > nsig_clipping**2, axis=0)
            selection = np.where(nout_per_wave > 0)[0]
            for i in selection:
                worst_entry = np.argmax(chi2[:, i])
                current_ivar[worst_entry, i] = 0
                sqrtw[worst_entry, i] = 0
                sqrtwflux[worst_entry, i] = 0
                nout_iter += 1

        else:
            # remove all of them at once
            bad = (chi2 > nsig_clipping**2)
            current_ivar *= (bad == 0)
            sqrtw *= (bad == 0)
            sqrtwflux *= (bad == 0)
            nout_iter += np.sum(bad)

        nout_tot += nout_iter

        sum_chi2 = float(np.sum(chi2))
        ndf = int(np.sum(chi2 > 0) - nwave)
        chi2pdf = 0.
        if ndf > 0:
            chi2pdf = sum_chi2 / ndf

        log.info("iter #%d chi2=%g ndf=%d chi2pdf=%f delta=%f nout=%d" %
                 (iteration, sum_chi2, ndf, chi2pdf,
                  abs(sum_chi2 - previous_chi2), nout_iter))

        if nout_iter == 0 and abs(sum_chi2 - previous_chi2) < 0.2:
            break
        previous_chi2 = sum_chi2 + 0.

    log.info("nout tot=%d" % nout_tot)

    # we now have to compute the sky model for all fibers
    # and propagate the uncertainties

    # no need to restore the original ivar to compute the model errors when modeling ivar
    # the sky inverse variances are very similar

    # we ignore here the fact that we have fit an angular variation,
    # so the sky model uncertainties are inaccurate

    log.info("compute the parameter covariance")
    try:
        parameter_covar = cholesky_invert(A)
    except np.linalg.LinAlgError:
        log.warning(
            "cholesky_invert failed, switching to np.linalg.pinv"
        )
        parameter_covar = np.linalg.pinv(A)

    log.info("compute mean resolution")
    # we make an approximation for the variance to save CPU time
    # we use the average resolution of all fibers in the frame:
    mean_res_data = np.mean(frame.resolution_data, axis=0)
    Rmean = Resolution(mean_res_data)

    log.info("compute convolved sky and ivar")

    # The parameters are directly the unconvolved sky
    # First convolve with average resolution :
    convolved_sky_covar = Rmean.dot(parameter_covar).dot(Rmean.T.todense())

    # and keep only the diagonal
    convolved_sky_var = np.diagonal(convolved_sky_covar)

    # inverse
    convolved_sky_ivar = (convolved_sky_var > 0) / (convolved_sky_var +
                                                    (convolved_sky_var == 0))

    # and simply assume it is the same for all spectra
    cskyivar = np.tile(convolved_sky_ivar,
                       frame.nspec).reshape(frame.nspec, nwave)

    # The sky model for each fiber (simple convolution with resolution of each fiber)
    cskyflux = np.zeros(frame.flux.shape)

    Pol = allfibers_monomials.T.dot(coef).T
    for fiber in range(frame.nspec):
        cskyflux[fiber] = frame.R[fiber].dot(Pol[fiber] * parameters)

    # look at chi2 per wavelength and increase sky variance to reach chi2/ndf=1
    if skyfibers.size > 1 and add_variance:
        modified_cskyivar = _model_variance(frame, cskyflux, cskyivar,
                                            skyfibers)
    else:
        modified_cskyivar = cskyivar.copy()

    # need to do better here
    mask = (cskyivar == 0).astype(np.uint32)

    return SkyModel(
        frame.wave.copy(),
        cskyflux,
        modified_cskyivar,
        mask,
        nrej=nout_tot,
        stat_ivar=cskyivar)  # keep a record of the statistical ivar for QA
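The polynomial in this fit is built from a monomial basis x**dx * y**dy * w**dw, with dx+dy capped by angular_variation_deg and dw by chromatic_variation_deg. A standalone sketch of the basis construction and its evaluation, with made-up dimensions:

import numpy as np

nfiber, nwave = 5, 100
x = np.linspace(-1, 1, nfiber)   # normalized focal-plane coordinates
y = np.linspace(-1, 1, nfiber)
w = np.linspace(-1, 1, nwave)    # normalized wavelength

angular_variation_deg = 1
chromatic_variation_deg = 1
monomials = []
for dx in range(angular_variation_deg + 1):
    for dy in range(angular_variation_deg + 1 - dx):   # dx+dy <= angular degree
        xypol = (x ** dx) * (y ** dy)
        for dw in range(chromatic_variation_deg + 1):
            monomials.append(np.outer(xypol, w ** dw))

basis = np.array(monomials)      # shape (ncoef, nfiber, nwave) = (6, 5, 100)
coef = np.zeros(len(monomials))
coef[0] = 1.0                    # constant term only
Pol = basis.T.dot(coef).T        # (nfiber, nwave), all ones for this coef
print(basis.shape, Pol.shape)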
Example #58
0
def main(args=None):

    log = get_logger()

    if args is None:
        args = parse()

    if args.lin_step is not None and args.log10_step is not None:
        log.critical(
            "cannot have both linear and logarthmic bins :-), choose either --lin-step or --log10-step"
        )
        return 12
    if args.coadd_cameras and (args.lin_step is not None
                               or args.log10_step is not None):
        log.critical(
            "cannot specify a new wavelength binning along with --coadd-cameras option"
        )
        return 12

    if len(args.infile) == 0:
        log.critical("You must specify input files")
        return 12

    log.info("reading input ...")

    # inspect headers
    input_is_frames = False
    input_is_spectra = False
    for filename in args.infile:
        ifile = fitsio.FITS(filename)
        head = ifile[0].read_header()
        identified = False
        if "EXTNAME" in head and head["EXTNAME"] == "FLUX":
            print(filename, "is a frame")
            input_is_frames = True
            identified = True
            ifile.close()
            continue
        for hdu in ifile:
            head = hdu.read_header()
            if "EXTNAME" in head and head["EXTNAME"].find("_FLUX") >= 0:
                print(filename, "is a spectra")
                input_is_spectra = True
                identified = True
                break
        ifile.close()
        if not identified:
            log.error(
                "{} not identified as a frame or spectra file".format(filename))
            sys.exit(1)

    if input_is_frames and input_is_spectra:
        log.error("cannot combine input spectra and frames")
        sys.exit(1)

    if input_is_spectra:
        spectra = read_spectra(args.infile[0])
        for filename in args.infile[1:]:
            log.info("append {}".format(filename))
            spectra.update(read_spectra(filename))
    else:  # frames
        frames = dict()
        cameras = {}
        for filename in args.infile:
            frame = read_frame(filename)
            night = frame.meta['NIGHT']
            expid = frame.meta['EXPID']
            camera = frame.meta['CAMERA']
            frames[(night, expid, camera)] = frame
            if args.coadd_cameras:
                cam, spec = camera[0], camera[1]
                # Keep a list of cameras (b,r,z) for each exposure + spec
                if (night, expid) not in cameras:
                    cameras[(night, expid)] = {spec: [cam]}
                elif spec not in cameras[(night, expid)]:
                    cameras[(night, expid)][spec] = [cam]
                else:
                    cameras[(night, expid)][spec].append(cam)

        if args.coadd_cameras:
            # If not all 3 cameras are available, remove the incomplete sets
            for (night, expid), camdict in cameras.items():
                for spec, camlist in camdict.items():
                    log.info("Found {} for SP{} on NIGHT {} EXP {}".format(
                        camlist, spec, night, expid))
                    if len(camlist) != 3 or np.any(
                            np.sort(camlist) != np.array(['b', 'r', 'z'])):
                        for cam in camlist:
                            frames.pop((night, expid, cam + spec))
                            log.warning(
                                "Removing {}{} from Night {} EXP {}".format(
                                    cam, spec, night, expid))
        spectra = frames2spectra(frames)

        #- hacks to make SpectraLite like a Spectra
        spectra.fibermap = Table(spectra.fibermap)

        del frames  #- maybe free some memory

    if args.coadd_cameras:
        log.info("coadding cameras ...")
        spectra = coadd_cameras(spectra, cosmics_nsig=args.nsig)
    else:
        log.info("coadding ...")
        coadd(spectra, cosmics_nsig=args.nsig)

    if args.lin_step is not None:
        log.info("resampling ...")
        spectra = resample_spectra_lin_or_log(spectra,
                                              linear_step=args.lin_step,
                                              wave_min=args.wave_min,
                                              wave_max=args.wave_max,
                                              fast=args.fast,
                                              nproc=args.nproc)
    if args.log10_step is not None:
        log.info("resampling ...")
        spectra = resample_spectra_lin_or_log(spectra,
                                              log10_step=args.log10_step,
                                              wave_min=args.wave_min,
                                              wave_max=args.wave_max,
                                              fast=args.fast,
                                              nproc=args.nproc)

    #- Add input files to header
    if spectra.meta is None:
        spectra.meta = dict()

    for i, filename in enumerate(args.infile):
        spectra.meta['INFIL{:03d}'.format(i)] = os.path.basename(filename)

    log.info("writing {} ...".format(args.outfile))
    write_spectra(args.outfile, spectra)

    log.info("done")
Example #59
0
def read_raw(filename, camera, **kwargs):
    '''
    Returns preprocessed raw data from `camera` extension of `filename`

    Args:
        filename : input fits filename with DESI raw data
        camera : camera name (B0,R1, .. Z9) or FITS extension name or number

    Options:
        Other keyword arguments are passed to desispec.preproc.preproc(),
        e.g. bias, pixflat, mask.  See preproc() documentation for details.

    Returns Image object with member variables pix, ivar, mask, readnoise
    '''

    log = get_logger()

    fx = fits.open(filename, memmap=False)
    if camera.upper() not in fx:
        raise IOError('Camera {} not in {}'.format(camera, filename))

    rawimage = fx[camera.upper()].data
    header = fx[camera.upper()].header
    primary_header = fx[0].header

    blacklist = [
        "EXTEND", "SIMPLE", "NAXIS1", "NAXIS2", "CHECKSUM", "DATASUM",
        "XTENSION", "EXTNAME", "COMMENT"
    ]
    if 'INHERIT' in header and header['INHERIT']:
        h0 = fx[0].header
        for key in h0:
            if (key not in blacklist) and (key not in header):
                header[key] = h0[key]

    if "fill_header" in kwargs:
        hdus = kwargs["fill_header"]
        if hdus is not None:
            log.info("will add header keywords from hdus %s" % str(hdus))
            for hdu in hdus:
                try:
                    ihdu = int(hdu)
                    hdu = ihdu
                except ValueError:
                    pass
                if hdu in fx:
                    hdu_header = fx[hdu].header
                    for key in hdu_header:
                        if (key not in blacklist) and (key not in header):
                            log.debug("adding {} = {}".format(
                                key, hdu_header[key]))
                            header[key] = hdu_header[key]
                        else:
                            log.debug(
                                "key %s already in header or blacklisted" %
                                key)
                else:
                    log.warning("warning HDU %s not in fits file" % str(hdu))

        kwargs.pop("fill_header")

    fx.close()

    img = desispec.preproc.preproc(rawimage, header, primary_header, **kwargs)
    return img
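Usage is one call per camera. A hedged example, assuming desispec.io exposes read_raw and that a raw exposure file exists at the placeholder path below (bias is one of the preproc options forwarded through kwargs):

from desispec.io import read_raw   # assumed import location

# hypothetical raw-data filename; kwargs go to desispec.preproc.preproc()
img = read_raw('desi-00000002.fits.fz', camera='b0', bias=False)
print(img.pix.shape, img.readnoise)   # preprocessed pixels and read noise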
Example #60
0
def make_mtl(targets, zcat=None, trim=False):
    """Adds NUMOBS, PRIORITY, and OBSCONDITIONS columns to a targets table.

    Parameters
    ----------
    targets : :class:`~numpy.array` or `~astropy.table.Table`
        A numpy rec array or astropy Table with at least the columns
        ``TARGETID``, ``DESI_TARGET``, ``NUMOBS_INIT``, ``PRIORITY_INIT``.
        or the corresponding columns for SV or commissioning.
    zcat : :class:`~astropy.table.Table`, optional
        Redshift catalog table with columns ``TARGETID``, ``NUMOBS``, ``Z``,
        ``ZWARN``.
    trim : :class:`bool`, optional
        If ``True``, don't include targets that don't need
        any more observations.  If ``False`` (default), include every input target.

    Returns
    -------
    :class:`~astropy.table.Table`
        MTL Table with targets columns plus:

        * NUMOBS_MORE    - number of additional observations requested
        * PRIORITY       - target priority (larger number = higher priority)
        * OBSCONDITIONS  - replaces old GRAYLAYER
    """
    # ADM set up the default logger.
    from desiutil.log import get_logger
    log = get_logger()

    # ADM determine whether the input targets are main survey, cmx or SV.
    colnames, masks, survey = main_cmx_or_sv(targets)
    # ADM set the first column to be the "desitarget" column
    desi_target, desi_mask = colnames[0], masks[0]

    # Trim targets from zcat that aren't in original targets table
    if zcat is not None:
        ok = np.in1d(zcat['TARGETID'], targets['TARGETID'])
        num_extra = np.count_nonzero(~ok)
        if num_extra > 0:
            log.warning("Ignoring {} zcat entries that aren't "
                        "in the input target list".format(num_extra))
            zcat = zcat[ok]

    n = len(targets)
    # ADM if the input target columns were incorrectly called NUMOBS or PRIORITY
    # ADM rename them to NUMOBS_INIT or PRIORITY_INIT.
    # ADM Note that the syntax is slightly different for a Table.
    for name in ['NUMOBS', 'PRIORITY']:
        if isinstance(targets, Table):
            try:
                targets.rename_column(name, name + '_INIT')
            except KeyError:
                pass
        else:
            targets.dtype.names = [
                name + '_INIT' if col == name else col
                for col in targets.dtype.names
            ]

    # ADM if a redshift catalog was passed, order it to match the input targets
    # ADM catalog on 'TARGETID'.
    if zcat is not None:
        # ADM there might be a quicker way to do this?
        # ADM set up a dictionary of the indexes of each target id.
        d = dict(zip(targets["TARGETID"], np.arange(n)))
        # ADM loop through the zcat and look-up the index in the dictionary.
        zmatcher = np.array([d[tid] for tid in zcat["TARGETID"]])
        ztargets = zcat
        if ztargets.masked:
            unobs = ztargets['NUMOBS'].mask
            ztargets['NUMOBS'][unobs] = 0
            unobsz = ztargets['Z'].mask
            ztargets['Z'][unobsz] = -1
            unobszw = ztargets['ZWARN'].mask
            ztargets['ZWARN'][unobszw] = -1
    else:
        ztargets = Table()
        ztargets['TARGETID'] = targets['TARGETID']
        ztargets['NUMOBS'] = np.zeros(n, dtype=np.int32)
        ztargets['Z'] = -1 * np.ones(n, dtype=np.float32)
        ztargets['ZWARN'] = -1 * np.ones(n, dtype=np.int32)
        # ADM if zcat wasn't passed, there is a one-to-one correspondence
        # ADM between the targets and the zcat.
        zmatcher = np.arange(n)

    # ADM extract just the targets that match the input zcat.
    targets_zmatcher = targets[zmatcher]

    # ADM use passed value of NUMOBS_INIT instead of calling the memory-heavy calc_numobs.
    # ztargets['NUMOBS_MORE'] = np.maximum(0, calc_numobs(ztargets) - ztargets['NUMOBS'])
    ztargets['NUMOBS_MORE'] = np.maximum(
        0, targets_zmatcher['NUMOBS_INIT'] - ztargets['NUMOBS'])

    # ADM we need a minor hack to ensure that BGS targets are observed once (and only once)
    # ADM every time, regardless of how many times they've previously been observed.
    # ADM I've turned this off for commissioning. Not sure if we'll keep it in general.
    if survey != 'cmx':
        ii = (targets_zmatcher[desi_target] & desi_mask.BGS_ANY) > 0
        ztargets['NUMOBS_MORE'][ii] = 1

    # ADM assign priorities, note that only things in the zcat can have changed priorities.
    # ADM anything else will be assigned PRIORITY_INIT, below.
    priority = calc_priority(targets_zmatcher, ztargets)

    # If priority went to 0==DONOTOBSERVE or 1==OBS or 2==DONE, then NUMOBS_MORE should also be 0.
    # ## mtl['NUMOBS_MORE'] = ztargets['NUMOBS_MORE']
    ii = (priority <= 2)
    log.info(
        '{:d} of {:d} targets have priority <= 2 (done or unobservable), setting NUMOBS_MORE=0.'.format(
            np.sum(ii), n))
    ztargets['NUMOBS_MORE'][ii] = 0

    # - Set the OBSCONDITIONS mask for each target bit.
    obscon = set_obsconditions(targets)

    # ADM set up the output mtl table.
    mtl = Table(targets)
    mtl.meta['EXTNAME'] = 'MTL'
    # ADM any target that wasn't matched to the ZCAT should retain its
    # ADM original (INIT) value of PRIORITY and NUMOBS.
    mtl['NUMOBS_MORE'] = mtl['NUMOBS_INIT']
    mtl['PRIORITY'] = mtl['PRIORITY_INIT']
    # ADM now populate the new mtl columns with the updated information.
    mtl['OBSCONDITIONS'] = obscon
    mtl['PRIORITY'][zmatcher] = priority
    mtl['NUMOBS_MORE'][zmatcher] = ztargets['NUMOBS_MORE']

    # Filter out any targets marked as done.
    if trim:
        notdone = mtl['NUMOBS_MORE'] > 0
        log.info('{:d} of {:d} targets are done, trimming these'.format(
            len(mtl) - np.sum(notdone), len(mtl)))
        mtl = mtl[notdone]

    # Filtering can reset the fill_value, which is just wrong wrong wrong
    # See https://github.com/astropy/astropy/issues/4707
    # and https://github.com/astropy/astropy/issues/4708
    mtl['NUMOBS_MORE'].fill_value = -1

    return mtl
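The zmatcher construction above orders the redshift catalog to match the targets table on TARGETID via a dictionary lookup. A minimal standalone sketch of the same pattern with toy IDs:

import numpy as np

targetid = np.array([11, 42, 7, 99])   # IDs in the targets table (toy values)
zcat_id = np.array([99, 11])           # IDs present in the redshift catalog

d = dict(zip(targetid, np.arange(len(targetid))))   # TARGETID -> row index
zmatcher = np.array([d[tid] for tid in zcat_id])
print(zmatcher)                                     # [3 0]
assert np.all(targetid[zmatcher] == zcat_id)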