示例#1
0
def main(args, comm=None):
    """Coadd the spectra of a brick, fit redshifts and write the results.

    The brick files are read (located from ``args.brick`` or listed
    explicitly in ``args.brickfiles``).  For every target the exposures of
    each channel are averaged, the channels are concatenated, and the result
    is resampled onto one merged wavelength grid.  The spectra are then
    passed to ``RedMonsterZfind``.  Without a communicator the fit runs in
    this process with ``args.nproc`` workers; with a communicator the spectra
    are divided among the ranks and the per-rank results are collected on
    rank 0, which runs the optional QA and writes the output file.

    If ``args.print_info`` is set no fit is done: a per-target median
    ``s2n`` value is printed and written to that file (root process only)
    and every process leaves through ``sys.exit()``.

    Args:
        args: parsed options.  Attributes read here are ``npoly``, ``nproc``,
            ``objtype``, ``brick``, ``brickfiles``, ``specprod_dir``,
            ``print_info``, ``first_spec``, ``nspec``, ``zrange_galaxy``,
            ``zrange_qso``, ``zrange_star``, ``qafile``, ``qafig``,
            ``outfile`` and ``zspec``.  ``npoly``, ``nproc``, ``objtype``
            and ``outfile`` may be modified in place.
        comm: optional MPI communicator; its ``rank``, ``size``, ``bcast``,
            ``send``, ``recv`` and ``barrier`` members are used.

    Returns:
        None.

    Raises:
        RuntimeError: if both ``args.brick`` and ``args.brickfiles`` are given.
        IOError: if ``args.qafig`` is set (QA figures are not implemented).
        SystemExit: if a channel appears in more than one input file, or
            after the ``args.print_info`` report has been produced.
    """
    log = get_logger()

    #- Only one process reports progress and writes output
    is_root = (comm is None) or (comm.rank == 0)

    if args.npoly < 0:
        log.warning("Need npoly>=0, changing this %d -> 0" % args.npoly)
        args.npoly = 0
    if args.nproc < 1:
        log.warning("Need nproc>=1, changing this %d -> 1" % args.nproc)
        args.nproc = 1

    if comm is not None:
        if args.nproc != 1:
            if comm.rank == 0:
                log.warning("Using MPI, forcing multiprocessing nproc -> 1")
            args.nproc = 1

    if args.objtype is not None:
        args.objtype = args.objtype.split(',')

    #- Read brick files for each channel
    if is_root:
        log.info("Reading bricks")
    brick = dict()
    if args.brick is not None:
        if len(args.brickfiles) != 0:
            raise RuntimeError(
                'Give -b/--brick or input brickfiles but not both')
        for channel in ('b', 'r', 'z'):
            #- The root process resolves the path and shares it with the rest
            filename = None
            if is_root:
                filename = io.findfile('brick',
                                       band=channel,
                                       brickname=args.brick,
                                       specprod_dir=args.specprod_dir)
            if comm is not None:
                filename = comm.bcast(filename, root=0)
            brick[channel] = io.Brick(filename)
    else:
        for filename in args.brickfiles:
            bx = io.Brick(filename)
            if bx.channel in brick:
                if is_root:
                    log.error('Channel {} in multiple input files'.format(
                        bx.channel))
                sys.exit(2)
            brick[bx.channel] = bx

    filters = list(brick.keys())
    if is_root:
        for fil in filters:
            log.info("Filter found: " + fil)

    #- Coadd individual exposures and combine channels
    #- Full coadd code is a bit slow, so try something quick and dirty for
    #- now to get something going for redshifting
    if is_root:
        log.info("Combining individual channels and exposures")
    wave = []
    for fil in filters:
        wave = np.concatenate([wave, brick[fil].get_wavelength_grid()])
    np.ndarray.sort(wave)

    #- flux and ivar rows are collected only for targets that have data in
    #- every channel
    flux = []
    ivar = []
    good_targetids = []
    #- TODO: generalize this to allow missing channels; the target list is
    #- taken from the 'b' brick, so that channel must be present
    targetids = brick['b'].get_target_ids()

    fpinfo = None
    if args.print_info is not None:
        if is_root:
            fpinfo = open(args.print_info, "w")

    try:
        for targetid in targetids:
            #- wave, flux, and ivar for this target; concatenate
            xwave = list()
            xflux = list()
            xivar = list()

            good = True
            for channel in filters:
                exp_flux, exp_ivar, resolution, info = \
                    brick[channel].get_target(targetid)
                weights = np.sum(exp_ivar, axis=0)
                ii, = np.where(weights > 0)
                if len(ii) == 0:
                    #- no usable pixel in this channel: drop the target
                    good = False
                    break
                xwave.extend(brick[channel].get_wavelength_grid()[ii])
                #- Average multiple exposures on the same wavelength grid
                #- for each channel
                xflux.extend(
                    np.average(exp_flux[:, ii],
                               weights=exp_ivar[:, ii],
                               axis=0))
                xivar.extend(weights[ii])

            if not good:
                continue

            xwave = np.array(xwave)
            xivar = np.array(xivar)
            xflux = np.array(xflux)

            ii = np.argsort(xwave)
            fl, iv = resample_flux(wave, xwave[ii], xflux[ii], xivar[ii])
            flux.append(fl)
            ivar.append(iv)
            good_targetids.append(targetid)
            if args.print_info is not None:
                s2n = np.median(fl[:-1] * np.sqrt(iv[:-1]) /
                                np.sqrt(wave[1:] - wave[:-1]))
                if is_root:
                    #- single-argument form prints the same text under
                    #- Python 2 and Python 3
                    print("{} {}".format(targetid, s2n))
                    fpinfo.write(str(targetid) + " " + str(s2n) + "\n")
    finally:
        #- close the report even if reading or resampling a target fails
        if fpinfo is not None:
            fpinfo.close()

    if args.print_info is not None:
        #- report-only mode: stop before any fitting
        sys.exit()

    good_targetids = good_targetids[args.first_spec:]
    flux = np.array(flux[args.first_spec:])
    ivar = np.array(ivar[args.first_spec:])
    nspec = len(good_targetids)
    if is_root:
        log.info("number of good targets = %d" % nspec)
    if (args.nspec is not None) and (args.nspec < nspec):
        if is_root:
            log.info("Fitting {} of {} targets".format(args.nspec, nspec))
        nspec = args.nspec
        good_targetids = good_targetids[:nspec]
        flux = flux[:nspec]
        ivar = ivar[:nspec]
    else:
        if is_root:
            log.info("Fitting {} targets".format(nspec))

    if is_root:
        log.debug("flux.shape={}".format(flux.shape))

    zf = None
    if comm is None:
        # Use multiprocessing built in to RedMonster.

        zf = RedMonsterZfind(wave=wave,
                             flux=flux,
                             ivar=ivar,
                             objtype=args.objtype,
                             zrange_galaxy=args.zrange_galaxy,
                             zrange_qso=args.zrange_qso,
                             zrange_star=args.zrange_star,
                             nproc=args.nproc,
                             npoly=args.npoly)

    else:
        # Use MPI

        # distribute the spectra among processes
        my_firstspec, my_nspec = dist_uniform(nspec, comm.size, comm.rank)
        my_specs = slice(my_firstspec, my_firstspec + my_nspec)
        # each rank reports in turn so the messages come out in rank order
        for p in range(comm.size):
            if p == comm.rank:
                if my_nspec > 0:
                    log.info("process {} fitting spectra {} - {}".format(
                        p, my_firstspec, my_firstspec + my_nspec - 1))
                else:
                    log.info("process {} idle".format(p))
                sys.stdout.flush()
            comm.barrier()

        # do redshift fitting on each process
        myzf = None
        if my_nspec > 0:
            # DESI_LOGLEVEL is forced to WARNING for the duration of the
            # fit and put back afterwards, also when the fit raises or
            # when the variable was not set beforehand.
            savelevel = os.environ.get("DESI_LOGLEVEL")
            os.environ["DESI_LOGLEVEL"] = "WARNING"
            try:
                myzf = RedMonsterZfind(wave=wave,
                                       flux=flux[my_specs, :],
                                       ivar=ivar[my_specs, :],
                                       objtype=args.objtype,
                                       zrange_galaxy=args.zrange_galaxy,
                                       zrange_qso=args.zrange_qso,
                                       zrange_star=args.zrange_star,
                                       nproc=args.nproc,
                                       npoly=args.npoly)
            finally:
                if savelevel is None:
                    del os.environ["DESI_LOGLEVEL"]
                else:
                    os.environ["DESI_LOGLEVEL"] = savelevel

        # Combine results into a single ZFindBase object on the root process.
        # We could do this with a gather, but we are using a small number of
        # processes, and point-to-point communication is easier for people to
        # understand.

        # NOTE(review): the root process needs its own fit result here, so
        # this fails if rank 0 was assigned no spectra (e.g. nspec == 0).
        if comm.rank == 0:
            zf = ZfindBase(myzf.wave,
                           np.zeros((nspec, myzf.nwave)),
                           np.zeros((nspec, myzf.nwave)),
                           R=None,
                           results=None)

        for p in range(comm.size):
            if comm.rank == 0:
                if p == 0:
                    # root process copies its own data into output
                    zf.flux[my_specs] = myzf.flux
                    zf.ivar[my_specs] = myzf.ivar
                    zf.model[my_specs] = myzf.model
                    zf.z[my_specs] = myzf.z
                    zf.zerr[my_specs] = myzf.zerr
                    zf.zwarn[my_specs] = myzf.zwarn
                    zf.spectype[my_specs] = myzf.spectype
                    zf.subtype[my_specs] = myzf.subtype
                else:
                    # root process receives from process p and copies
                    # it into the output.
                    p_nspec = comm.recv(source=p, tag=0)
                    # only proceed if the sending process actually
                    # has some spectra assigned to it.
                    if p_nspec > 0:
                        p_firstspec = comm.recv(source=p, tag=1)
                        p_slice = slice(p_firstspec, p_firstspec + p_nspec)

                        p_flux = comm.recv(source=p, tag=2)
                        zf.flux[p_slice] = p_flux

                        p_ivar = comm.recv(source=p, tag=3)
                        zf.ivar[p_slice] = p_ivar

                        p_model = comm.recv(source=p, tag=4)
                        zf.model[p_slice] = p_model

                        p_z = comm.recv(source=p, tag=5)
                        zf.z[p_slice] = p_z

                        p_zerr = comm.recv(source=p, tag=6)
                        zf.zerr[p_slice] = p_zerr

                        p_zwarn = comm.recv(source=p, tag=7)
                        zf.zwarn[p_slice] = p_zwarn

                        p_type = comm.recv(source=p, tag=8)
                        zf.spectype[p_slice] = p_type

                        p_subtype = comm.recv(source=p, tag=9)
                        zf.subtype[p_slice] = p_subtype
            else:
                if p == comm.rank:
                    # process p sends to root
                    comm.send(my_nspec, dest=0, tag=0)
                    if my_nspec > 0:
                        comm.send(my_firstspec, dest=0, tag=1)
                        comm.send(myzf.flux, dest=0, tag=2)
                        comm.send(myzf.ivar, dest=0, tag=3)
                        comm.send(myzf.model, dest=0, tag=4)
                        comm.send(myzf.z, dest=0, tag=5)
                        comm.send(myzf.zerr, dest=0, tag=6)
                        comm.send(myzf.zwarn, dest=0, tag=7)
                        comm.send(myzf.spectype, dest=0, tag=8)
                        comm.send(myzf.subtype, dest=0, tag=9)
            comm.barrier()

    if is_root:
        # The full results exist only on the rank zero process.

        # reformat results into one structured array, one row per spectrum
        dtype = [
            ('Z', zf.z.dtype),
            ('ZERR', zf.zerr.dtype),
            ('ZWARN', zf.zwarn.dtype),
            ('SPECTYPE', zf.spectype.dtype),
            ('SUBTYPE', zf.subtype.dtype),
        ]

        formatted_data = np.empty(nspec, dtype=dtype)
        formatted_data['Z'] = zf.z
        formatted_data['ZERR'] = zf.zerr
        formatted_data['ZWARN'] = zf.zwarn
        formatted_data['SPECTYPE'] = zf.spectype
        formatted_data['SUBTYPE'] = zf.subtype

        # Create a ZfindBase object with formatted results
        zfi = ZfindBase(None, None, None, results=formatted_data)
        zfi.nspec = nspec

        # QA
        if (args.qafile is not None) or (args.qafig is not None):
            log.info("performing skysub QA")
            # Load
            qabrick = load_qa_brick(args.qafile)
            # Run
            qabrick.run_qa('ZBEST', (zfi, brick))
            # Write
            if args.qafile is not None:
                write_qa_brick(args.qafile, qabrick)
                log.info("successfully wrote {:s}".format(args.qafile))
            # Figure(s)
            if args.qafig is not None:
                # TODO: call qa_plots.brick_zbest(args.qafig, zfi, qabrick)
                # once the figure code exists
                raise IOError("Not yet implemented")

        #- Write some output
        if args.outfile is None:
            args.outfile = io.findfile('zbest', brickname=args.brick)

        log.info("Writing " + args.outfile)
        #- only the targets that were actually fitted are written
        io.write_zbest(args.outfile,
                       args.brick,
                       good_targetids,
                       zfi,
                       zspec=args.zspec)

    return
示例#2
0
文件: run.py 项目: rstaten/desispec
def run_step(step, rawdir, proddir, grph, opts, comm=None, taskproc=1):
    '''
    Run a whole single step of the pipeline.

    This function first takes the communicator and the requested processes
    per task and splits the communicator to form groups of processes of
    the desired size.  It then takes the full dependency graph and extracts
    all the tasks for a given step.  These tasks are then distributed among
    the groups of processes.

    Each process group loops over its assigned tasks.  For each task, it
    redirects stdout/stderr to a per-task file and calls run_task().  If
    any process in the group throws an exception, then the traceback and
    all information (graph and options) needed to re-run the task are written
    to disk.

    After all process groups have finished, the state of the full graph is
    merged from all processes.  This way a failure of one process on one task
    will be propagated as a failed task to all processes.

    Args:
        step (str): the pipeline step to process.
        rawdir (str): the path to the raw data directory.
        proddir (str): the path to the production directory.
        grph (dict): the dependency graph.
        opts (dict): the global options.
        comm (mpi4py.Comm): the full communicator to use for whole step.
        taskproc (int): the number of processes to use for a single task.

    Returns:
        dict: the dependency graph, with the task states of this step
            updated and (when a communicator is given) merged across
            processes.

    Raises:
        RuntimeError: if taskproc is larger than the number of processes.
    '''
    log = get_logger()

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    if taskproc > nproc:
        raise RuntimeError("cannot have {} processes per task with only {} processes".format(taskproc, nproc))

    # Get the tasks that need to be done for this step.  Tasks already
    # marked as done are skipped.  Only the root process builds the list;
    # it is broadcast below.

    tasks = None
    if rank == 0:
        # For this step, compute all the tasks that we need to do
        alltasks = [
            name for name, nd in sorted(grph.items())
            if nd['type'] in step_file_types[step]
        ]

        # For each task, prune if it is finished.  A node without a
        # 'state' entry counts as not done.
        tasks = [t for t in alltasks if grph[t].get('state') != 'done']

    if comm is not None:
        tasks = comm.bcast(tasks, root=0)
        grph = comm.bcast(grph, root=0)

    ntask = len(tasks)

    # Get the options for this step.

    options = opts[step]

    # Now every process has the full list of tasks.  If we have multiple
    # processes for each task, split the communicator.
    #   comm_group: the processes working together on one task
    #   comm_rank:  the processes that share the same rank within their group

    comm_group = comm
    comm_rank = None
    group = rank
    ngroup = nproc
    group_rank = 0
    if comm is not None:
        if taskproc > 1:
            ngroup = nproc // taskproc
            group = rank // taskproc
            group_rank = rank % taskproc
            comm_group = comm.Split(color=group, key=group_rank)
            comm_rank = comm.Split(color=group_rank, key=group)
        else:
            comm_group = None
            comm_rank = comm

    # Now we divide up the tasks among the groups of processes as
    # equally as possible.

    group_ntask = 0
    group_firsttask = 0

    if group < ngroup:
        # only assign tasks to whole groups
        if ntask < ngroup:
            if group < ntask:
                group_ntask = 1
                group_firsttask = group
            else:
                group_ntask = 0
        else:
            if step == 'zfind':
                # We load balance the bricks across process groups based
                # on the number of targets per brick.  All bricks with
                # < taskproc targets are weighted the same.
                #
                # The work sizes are computed up front because the group
                # summary below needs them whichever distribution is used.
                bricksizes = [grph[x]['ntarget'] for x in tasks]
                worksizes = [taskproc if (x < taskproc) else x for x in bricksizes]

                if ntask <= ngroup:
                    # distribute uniform in this case
                    group_firsttask, group_ntask = dist_uniform(ntask, ngroup, group)
                else:
                    if rank == 0:
                        log.debug("zfind {} groups".format(ngroup))
                        workstr = "".join("{} ".format(w) for w in worksizes)
                        log.debug("zfind work sizes = {}".format(workstr))

                    group_firsttask, group_ntask = dist_discrete(worksizes, ngroup, group)

                if group_rank == 0:
                    worksum = np.sum(worksizes[group_firsttask:group_firsttask+group_ntask])
                    log.debug("group {} has tasks {}-{} sum = {}".format(group, group_firsttask, group_firsttask+group_ntask-1, worksum))

            else:
                group_firsttask, group_ntask = dist_uniform(ntask, ngroup, group)

    # every group goes and does its tasks...

    faildir = os.path.join(proddir, 'run', 'failed')
    logdir = os.path.join(proddir, 'run', 'logs')

    if group_ntask > 0:
        for t in range(group_firsttask, group_firsttask + group_ntask):
            (night, gname) = graph_name_split(tasks[t])
            nfaildir = os.path.join(faildir, night)
            nlogdir = os.path.join(logdir, night)

            # slice out just the graph for this task
            tgraph = graph_slice(grph, names=[tasks[t]], deps=True)
            ffile = os.path.join(nfaildir, "{}_{}.yaml".format(step, tasks[t]))

            # For this task, we will temporarily redirect stdout and stderr
            # to a task-specific log file.

            with stdouterr_redirected(to=os.path.join(nlogdir, "{}.log".format(gname)), comm=comm_group):
                try:
                    # if the step previously failed, clear that file now
                    if group_rank == 0:
                        if os.path.isfile(ffile):
                            os.remove(ffile)
                    log.debug("running step {} task {} (group {}/{} with {} processes)".format(step, tasks[t], (group+1), ngroup, taskproc))
                    run_task(step, rawdir, proddir, tgraph, options, comm=comm_group)
                    # mark step as done in our group's graph
                    graph_mark(grph, tasks[t], state='done', descend=False)
                except:
                    # NOTE(review): the bare except also traps SystemExit and
                    # KeyboardInterrupt; it is kept because narrowing it
                    # would change which failures get recorded.
                    #
                    # The task threw an exception.  We want to dump all information
                    # that will be needed to re-run the run_task() function on just
                    # this task.
                    msg = "FAILED: step {} task {} (group {}/{} with {} processes)".format(step, tasks[t], (group+1), ngroup, taskproc)
                    log.error(msg)
                    exc_type, exc_value, exc_traceback = sys.exc_info()
                    lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
                    log.error(''.join(lines))
                    fyml = {
                        'step': step,
                        'rawdir': rawdir,
                        'proddir': proddir,
                        'task': tasks[t],
                        'graph': tgraph,
                        'opts': options,
                        'procs': taskproc,
                    }
                    if not os.path.isfile(ffile):
                        log.error('Dumping yaml graph to '+ffile)
                        # we are the first process to hit this
                        with open(ffile, 'w') as f:
                            yaml.dump(fyml, f, default_flow_style=False)
                    # mark the step as failed in our group's local graph
                    graph_mark(grph, tasks[t], state='fail', descend=True)

        if comm_group is not None:
            comm_group.barrier()

    # Now we take the graphs from all groups and merge their states

    if comm is not None:
        # the first process of every group takes part in the merge ...
        if group_rank == 0:
            graph_merge_state(grph, comm=comm_rank)
        # ... and then hands the merged graph to the rest of its group
        if comm_group is not None:
            grph = comm_group.bcast(grph, root=0)

    return grph
示例#3
0
文件: zfind.py 项目: rstaten/desispec
def main(args, comm=None):
    """Coadd the spectra of a brick, fit redshifts and write the results.

    The brick files are read (located from ``args.brick`` or listed
    explicitly in ``args.brickfiles``).  For every target the exposures of
    each channel are averaged, the channels are concatenated, and the result
    is resampled onto one merged wavelength grid.  The spectra are then
    passed to ``RedMonsterZfind``.  Without a communicator the fit runs in
    this process with ``args.nproc`` workers; with a communicator the spectra
    are divided among the ranks and the per-rank results are collected on
    rank 0, which runs the optional QA and writes the output file.

    If ``args.print_info`` is set no fit is done: a per-target median
    ``s2n`` value is printed and written to that file (root process only)
    and every process leaves through ``sys.exit()``.

    Args:
        args: parsed options.  Attributes read here are ``npoly``, ``nproc``,
            ``objtype``, ``brick``, ``brickfiles``, ``specprod_dir``,
            ``print_info``, ``first_spec``, ``nspec``, ``zrange_galaxy``,
            ``zrange_qso``, ``zrange_star``, ``qafile``, ``qafig``,
            ``outfile`` and ``zspec``.  ``npoly``, ``nproc``, ``objtype``
            and ``outfile`` may be modified in place.
        comm: optional MPI communicator; its ``rank``, ``size``, ``bcast``,
            ``send``, ``recv`` and ``barrier`` members are used.

    Returns:
        None.

    Raises:
        RuntimeError: if both ``args.brick`` and ``args.brickfiles`` are given.
        IOError: if ``args.qafig`` is set (QA figures are not implemented).
        SystemExit: if a channel appears in more than one input file, or
            after the ``args.print_info`` report has been produced.
    """
    log = get_logger()

    #- Only one process reports progress and writes output
    is_root = (comm is None) or (comm.rank == 0)

    if args.npoly < 0:
        log.warning("Need npoly>=0, changing this %d -> 0" % args.npoly)
        args.npoly = 0
    if args.nproc < 1:
        log.warning("Need nproc>=1, changing this %d -> 1" % args.nproc)
        args.nproc = 1

    if comm is not None:
        if args.nproc != 1:
            if comm.rank == 0:
                log.warning("Using MPI, forcing multiprocessing nproc -> 1")
            args.nproc = 1

    if args.objtype is not None:
        args.objtype = args.objtype.split(',')

    #- Read brick files for each channel
    if is_root:
        log.info("Reading bricks")
    brick = dict()
    if args.brick is not None:
        if len(args.brickfiles) != 0:
            raise RuntimeError('Give -b/--brick or input brickfiles but not both')
        for channel in ('b', 'r', 'z'):
            #- The root process resolves the path and shares it with the rest
            filename = None
            if is_root:
                filename = io.findfile('brick', band=channel, brickname=args.brick,
                                       specprod_dir=args.specprod_dir)
            if comm is not None:
                filename = comm.bcast(filename, root=0)
            brick[channel] = io.Brick(filename)
    else:
        for filename in args.brickfiles:
            bx = io.Brick(filename)
            if bx.channel in brick:
                if is_root:
                    log.error('Channel {} in multiple input files'.format(bx.channel))
                sys.exit(2)
            brick[bx.channel] = bx

    filters = list(brick.keys())
    if is_root:
        for fil in filters:
            log.info("Filter found: " + fil)

    #- Coadd individual exposures and combine channels
    #- Full coadd code is a bit slow, so try something quick and dirty for
    #- now to get something going for redshifting
    if is_root:
        log.info("Combining individual channels and exposures")
    wave = []
    for fil in filters:
        wave = np.concatenate([wave, brick[fil].get_wavelength_grid()])
    np.ndarray.sort(wave)

    #- flux and ivar rows are collected only for targets that have data in
    #- every channel
    flux = []
    ivar = []
    good_targetids = []
    #- TODO: generalize this to allow missing channels; the target list is
    #- taken from the 'b' brick, so that channel must be present
    targetids = brick['b'].get_target_ids()

    fpinfo = None
    if args.print_info is not None:
        if is_root:
            fpinfo = open(args.print_info, "w")

    try:
        for targetid in targetids:
            #- wave, flux, and ivar for this target; concatenate
            xwave = list()
            xflux = list()
            xivar = list()

            good = True
            for channel in filters:
                exp_flux, exp_ivar, resolution, info = brick[channel].get_target(targetid)
                weights = np.sum(exp_ivar, axis=0)
                ii, = np.where(weights > 0)
                if len(ii) == 0:
                    #- no usable pixel in this channel: drop the target
                    good = False
                    break
                xwave.extend(brick[channel].get_wavelength_grid()[ii])
                #- Average multiple exposures on the same wavelength grid
                #- for each channel
                xflux.extend(np.average(exp_flux[:, ii], weights=exp_ivar[:, ii], axis=0))
                xivar.extend(weights[ii])

            if not good:
                continue

            xwave = np.array(xwave)
            xivar = np.array(xivar)
            xflux = np.array(xflux)

            ii = np.argsort(xwave)
            fl, iv = resample_flux(wave, xwave[ii], xflux[ii], xivar[ii])
            flux.append(fl)
            ivar.append(iv)
            good_targetids.append(targetid)
            if args.print_info is not None:
                s2n = np.median(fl[:-1] * np.sqrt(iv[:-1]) / np.sqrt(wave[1:] - wave[:-1]))
                if is_root:
                    #- single-argument form prints the same text under
                    #- Python 2 and Python 3
                    print("{} {}".format(targetid, s2n))
                    fpinfo.write(str(targetid) + " " + str(s2n) + "\n")
    finally:
        #- close the report even if reading or resampling a target fails
        if fpinfo is not None:
            fpinfo.close()

    if args.print_info is not None:
        #- report-only mode: stop before any fitting
        sys.exit()

    good_targetids = good_targetids[args.first_spec:]
    flux = np.array(flux[args.first_spec:])
    ivar = np.array(ivar[args.first_spec:])
    nspec = len(good_targetids)
    if is_root:
        log.info("number of good targets = %d" % nspec)
    if (args.nspec is not None) and (args.nspec < nspec):
        if is_root:
            log.info("Fitting {} of {} targets".format(args.nspec, nspec))
        nspec = args.nspec
        good_targetids = good_targetids[:nspec]
        flux = flux[:nspec]
        ivar = ivar[:nspec]
    else:
        if is_root:
            log.info("Fitting {} targets".format(nspec))

    if is_root:
        log.debug("flux.shape={}".format(flux.shape))

    zf = None
    if comm is None:
        # Use multiprocessing built in to RedMonster.

        zf = RedMonsterZfind(wave=wave, flux=flux, ivar=ivar,
                             objtype=args.objtype, zrange_galaxy=args.zrange_galaxy,
                             zrange_qso=args.zrange_qso, zrange_star=args.zrange_star,
                             nproc=args.nproc, npoly=args.npoly)

    else:
        # Use MPI

        # distribute the spectra among processes
        my_firstspec, my_nspec = dist_uniform(nspec, comm.size, comm.rank)
        my_specs = slice(my_firstspec, my_firstspec + my_nspec)
        # each rank reports in turn so the messages come out in rank order
        for p in range(comm.size):
            if p == comm.rank:
                if my_nspec > 0:
                    log.info("process {} fitting spectra {} - {}".format(p, my_firstspec, my_firstspec+my_nspec-1))
                else:
                    log.info("process {} idle".format(p))
                sys.stdout.flush()
            comm.barrier()

        # do redshift fitting on each process
        myzf = None
        if my_nspec > 0:
            # DESI_LOGLEVEL is forced to WARNING for the duration of the
            # fit and put back afterwards, also when the fit raises or
            # when the variable was not set beforehand.
            savelevel = os.environ.get("DESI_LOGLEVEL")
            os.environ["DESI_LOGLEVEL"] = "WARNING"
            try:
                myzf = RedMonsterZfind(wave=wave, flux=flux[my_specs, :], ivar=ivar[my_specs, :],
                                       objtype=args.objtype, zrange_galaxy=args.zrange_galaxy,
                                       zrange_qso=args.zrange_qso, zrange_star=args.zrange_star,
                                       nproc=args.nproc, npoly=args.npoly)
            finally:
                if savelevel is None:
                    del os.environ["DESI_LOGLEVEL"]
                else:
                    os.environ["DESI_LOGLEVEL"] = savelevel

        # Combine results into a single ZFindBase object on the root process.
        # We could do this with a gather, but we are using a small number of
        # processes, and point-to-point communication is easier for people to
        # understand.

        # NOTE(review): the root process needs its own fit result here, so
        # this fails if rank 0 was assigned no spectra (e.g. nspec == 0).
        if comm.rank == 0:
            zf = ZfindBase(myzf.wave, np.zeros((nspec, myzf.nwave)), np.zeros((nspec, myzf.nwave)), R=None, results=None)

        for p in range(comm.size):
            if comm.rank == 0:
                if p == 0:
                    # root process copies its own data into output
                    zf.flux[my_specs] = myzf.flux
                    zf.ivar[my_specs] = myzf.ivar
                    zf.model[my_specs] = myzf.model
                    zf.z[my_specs] = myzf.z
                    zf.zerr[my_specs] = myzf.zerr
                    zf.zwarn[my_specs] = myzf.zwarn
                    zf.spectype[my_specs] = myzf.spectype
                    zf.subtype[my_specs] = myzf.subtype
                else:
                    # root process receives from process p and copies
                    # it into the output.
                    p_nspec = comm.recv(source=p, tag=0)
                    # only proceed if the sending process actually
                    # has some spectra assigned to it.
                    if p_nspec > 0:
                        p_firstspec = comm.recv(source=p, tag=1)
                        p_slice = slice(p_firstspec, p_firstspec+p_nspec)

                        p_flux = comm.recv(source=p, tag=2)
                        zf.flux[p_slice] = p_flux

                        p_ivar = comm.recv(source=p, tag=3)
                        zf.ivar[p_slice] = p_ivar

                        p_model = comm.recv(source=p, tag=4)
                        zf.model[p_slice] = p_model

                        p_z = comm.recv(source=p, tag=5)
                        zf.z[p_slice] = p_z

                        p_zerr = comm.recv(source=p, tag=6)
                        zf.zerr[p_slice] = p_zerr

                        p_zwarn = comm.recv(source=p, tag=7)
                        zf.zwarn[p_slice] = p_zwarn

                        p_type = comm.recv(source=p, tag=8)
                        zf.spectype[p_slice] = p_type

                        p_subtype = comm.recv(source=p, tag=9)
                        zf.subtype[p_slice] = p_subtype
            else:
                if p == comm.rank:
                    # process p sends to root
                    comm.send(my_nspec, dest=0, tag=0)
                    if my_nspec > 0:
                        comm.send(my_firstspec, dest=0, tag=1)
                        comm.send(myzf.flux, dest=0, tag=2)
                        comm.send(myzf.ivar, dest=0, tag=3)
                        comm.send(myzf.model, dest=0, tag=4)
                        comm.send(myzf.z, dest=0, tag=5)
                        comm.send(myzf.zerr, dest=0, tag=6)
                        comm.send(myzf.zwarn, dest=0, tag=7)
                        comm.send(myzf.spectype, dest=0, tag=8)
                        comm.send(myzf.subtype, dest=0, tag=9)
            comm.barrier()

    if is_root:
        # The full results exist only on the rank zero process.

        # reformat results into one structured array, one row per spectrum
        dtype = [
            ('Z',         zf.z.dtype),
            ('ZERR',      zf.zerr.dtype),
            ('ZWARN',     zf.zwarn.dtype),
            ('SPECTYPE',  zf.spectype.dtype),
            ('SUBTYPE',   zf.subtype.dtype),
        ]

        formatted_data = np.empty(nspec, dtype=dtype)
        formatted_data['Z'] = zf.z
        formatted_data['ZERR'] = zf.zerr
        formatted_data['ZWARN'] = zf.zwarn
        formatted_data['SPECTYPE'] = zf.spectype
        formatted_data['SUBTYPE'] = zf.subtype

        # Create a ZfindBase object with formatted results
        zfi = ZfindBase(None, None, None, results=formatted_data)
        zfi.nspec = nspec

        # QA
        if (args.qafile is not None) or (args.qafig is not None):
            log.info("performing skysub QA")
            # Load
            qabrick = load_qa_brick(args.qafile)
            # Run
            qabrick.run_qa('ZBEST', (zfi, brick))
            # Write
            if args.qafile is not None:
                write_qa_brick(args.qafile, qabrick)
                log.info("successfully wrote {:s}".format(args.qafile))
            # Figure(s)
            if args.qafig is not None:
                # TODO: call qa_plots.brick_zbest(args.qafig, zfi, qabrick)
                # once the figure code exists
                raise IOError("Not yet implemented")

        #- Write some output
        if args.outfile is None:
            args.outfile = io.findfile('zbest', brickname=args.brick)

        log.info("Writing " + args.outfile)
        #- only the targets that were actually fitted are written
        io.write_zbest(args.outfile, args.brick, good_targetids, zfi, zspec=args.zspec)

    return
示例#4
0
def run_step(step, rawdir, proddir, grph, opts, comm=None, taskproc=1):
    '''
    Run every unfinished task of one pipeline step.

    Rank zero selects the graph nodes whose type belongs to this step and
    drops those already in state 'done'; the task list and the graph are
    then broadcast.  The communicator is split into groups of ``taskproc``
    processes, the tasks are distributed among the groups, and each group
    calls run_task() on its tasks with stdout/stderr redirected to a
    per-task log file.  If a task raises, the traceback is logged, the
    information needed to re-run the task is dumped to a YAML file, and
    the task is marked as failed.  Finally the graph states of all groups
    are merged.

    Args:
        step (str): the pipeline step to process.
        rawdir (str): the path to the raw data directory.
        proddir (str): the path to the production directory.
        grph (dict): the dependency graph.
        opts (dict): the options, keyed by step name.
        comm: the communicator to use for the whole step, or None to run
            without MPI.
        taskproc (int): the number of processes to use for a single task.

    Returns:
        dict: the dependency graph with the task states updated (and
        merged across process groups when a communicator is given).

    Raises:
        RuntimeError: if taskproc is larger than the number of processes.
    '''
    log = get_logger()

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    if taskproc > nproc:
        raise RuntimeError("cannot have {} processes per task with only {} processes".format(taskproc, nproc))

    # Get the tasks that need to be done for this step.  Only rank zero
    # computes the list; everyone else receives it in the broadcast below.

    tasks = None
    if rank == 0:
        # For this step, compute all the tasks that we need to do
        alltasks = [
            name for name, nd in sorted(grph.items())
            if nd['type'] in step_file_types[step]
        ]

        # Prune the tasks that are already finished (a node without a
        # 'state' entry is treated as not done).
        tasks = [t for t in alltasks if grph[t].get('state') != 'done']

    if comm is not None:
        tasks = comm.bcast(tasks, root=0)
        grph = comm.bcast(grph, root=0)

    ntask = len(tasks)

    # Get the options for this step.

    options = opts[step]

    # Now every process has the full list of tasks.  If we have multiple
    # processes for each task, split the communicator: comm_group joins the
    # processes working on the same task, comm_rank joins the processes
    # that have the same rank within their group.

    comm_group = comm
    comm_rank = None
    group = rank
    ngroup = nproc
    group_rank = 0
    if comm is not None:
        if taskproc > 1:
            ngroup = int(nproc / taskproc)
            group = int(rank / taskproc)
            group_rank = rank % taskproc
            comm_group = comm.Split(color=group, key=group_rank)
            comm_rank = comm.Split(color=group_rank, key=group)
        else:
            comm_group = None
            comm_rank = comm

    # Now we divide up the tasks among the groups of processes as
    # equally as possible.

    group_ntask = 0
    group_firsttask = 0

    if group < ngroup:
        # only assign tasks to whole groups
        if ntask < ngroup:
            # Fewer tasks than groups: the first ntask groups get one each.
            if group < ntask:
                group_ntask = 1
                group_firsttask = group
        elif step == 'zfind':
            # We load balance the bricks across process groups based
            # on the number of targets per brick.  All bricks with
            # < taskproc targets are weighted the same.
            #
            # The work sizes are computed before choosing the distribution
            # because the debug summary below needs them on both paths;
            # previously they were unbound when ntask == ngroup.
            bricksizes = [grph[x]['ntarget'] for x in tasks]
            worksizes = [taskproc if (x < taskproc) else x for x in bricksizes]

            if ntask <= ngroup:
                # distribute uniform in this case
                group_firsttask, group_ntask = dist_uniform(ntask, ngroup, group)
            else:
                if rank == 0:
                    log.debug("zfind {} groups".format(ngroup))
                    workstr = "".join("{} ".format(w) for w in worksizes)
                    log.debug("zfind work sizes = {}".format(workstr))

                group_firsttask, group_ntask = dist_discrete(worksizes, ngroup, group)

            if group_rank == 0:
                worksum = np.sum(worksizes[group_firsttask:group_firsttask+group_ntask])
                log.debug("group {} has tasks {}-{} sum = {}".format(group, group_firsttask, group_firsttask+group_ntask-1, worksum))
        else:
            group_firsttask, group_ntask = dist_uniform(ntask, ngroup, group)

    # every group goes and does its tasks...

    faildir = os.path.join(proddir, 'run', 'failed')
    logdir = os.path.join(proddir, 'run', 'logs')

    if group_ntask > 0:
        for t in range(group_firsttask, group_firsttask + group_ntask):
            (night, gname) = graph_name_split(tasks[t])
            nfaildir = os.path.join(faildir, night)
            nlogdir = os.path.join(logdir, night)

            # slice out just the graph for this task
            tgraph = graph_slice(grph, names=[tasks[t]], deps=True)
            ffile = os.path.join(nfaildir, "{}_{}.yaml".format(step, tasks[t]))

            # For this task, we will temporarily redirect stdout and stderr
            # to a task-specific log file.

            with stdouterr_redirected(to=os.path.join(nlogdir, "{}.log".format(gname)), comm=comm_group):
                try:
                    # if the step previously failed, clear that file now
                    if group_rank == 0:
                        if os.path.isfile(ffile):
                            os.remove(ffile)
                    log.debug("running step {} task {} (group {}/{} with {} processes)".format(step, tasks[t], (group+1), ngroup, taskproc))
                    run_task(step, rawdir, proddir, tgraph, options, comm=comm_group)
                    # mark step as done in our group's graph
                    graph_mark(grph, tasks[t], state='done', descend=False)
                except:
                    # NOTE(review): the bare except also traps SystemExit and
                    # KeyboardInterrupt; presumably deliberate so that any
                    # abort inside run_task() is recorded as a failed task --
                    # confirm before narrowing it.
                    #
                    # The task threw an exception.  We want to dump all information
                    # that will be needed to re-run the run_task() function on just
                    # this task.
                    msg = "FAILED: step {} task {} (group {}/{} with {} processes)".format(step, tasks[t], (group+1), ngroup, taskproc)
                    log.error(msg)
                    log.error(traceback.format_exc())
                    fyml = {
                        'step': step,
                        'rawdir': rawdir,
                        'proddir': proddir,
                        'task': tasks[t],
                        'graph': tgraph,
                        'opts': options,
                        'procs': taskproc,
                    }
                    if not os.path.isfile(ffile):
                        log.error('Dumping yaml graph to '+ffile)
                        # we are the first process to hit this
                        with open(ffile, 'w') as f:
                            yaml.dump(fyml, f, default_flow_style=False)
                    # mark the step as failed in our group's local graph
                    graph_mark(grph, tasks[t], state='fail', descend=True)

        if comm_group is not None:
            comm_group.barrier()

    # Now we take the graphs from all groups and merge their states.  The
    # first process of every group takes part in the merge and then shares
    # the merged graph with the rest of its group.

    if comm is not None:
        if group_rank == 0:
            graph_merge_state(grph, comm=comm_rank)
        if comm_group is not None:
            grph = comm_group.bcast(grph, root=0)

    return grph
示例#5
0
文件: run.py 项目: secroun/desispec
def run_step(step, rawdir, proddir, grph, opts, comm=None, taskproc=1):
    '''
    Run a whole single step of the pipeline.

    This function first takes the communicator and the requested processes
    per task and splits the communicator to form groups of processes of
    the desired size.  It then takes the full dependency graph and extracts
    all the tasks for a given step.  These tasks are then distributed among
    the groups of processes.

    Each process group loops over its assigned tasks.  A task whose input
    files do not all exist is skipped and counted as a failure.  For every
    other task, the group redirects stdout/stderr to a per-task file and
    calls run_task().  If the task throws an exception, then the traceback
    and all information (graph and options) needed to re-run the task are
    written to disk.

    After all process groups have finished, the state of the full graph is
    merged from all processes.  This way a failure of one process on one task
    will be propagated as a failed task to all processes.

    Args:
        step (str): the pipeline step to process.
        rawdir (str): the path to the raw data directory.
        proddir (str): the path to the production directory.
        grph (dict): the dependency graph.
        opts (dict): the global options.
        comm (mpi4py.Comm): the full communicator to use for whole step.
        taskproc (int): the number of processes to use for a single task.

    Returns:
        tuple: ``(grph, ntask, failcount)`` where ``grph`` is the dependency
        graph with the merged task states, ``ntask`` is the number of
        not-yet-done tasks found for this step, and ``failcount`` is the
        number of tasks that were skipped or failed, combined over all
        process groups when a communicator is given.

    Raises:
        RuntimeError: if taskproc is larger than the number of processes.
    '''
    log = get_logger()

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    if taskproc > nproc:
        raise RuntimeError(
            "cannot have {} processes per task with only {} processes".format(
                taskproc, nproc))

    # Get the tasks that need to be done for this step.  Only rank zero
    # computes the list; everyone else receives it in the broadcast below.

    tasks = None
    if rank == 0:
        # For this step, compute all the tasks that we need to do
        alltasks = [
            name for name, nd in sorted(grph.items())
            if nd['type'] in step_file_types[step]
        ]

        # Prune the tasks that are already finished (a node without a
        # 'state' entry is treated as not done).
        tasks = [t for t in alltasks if grph[t].get('state') != 'done']

    if comm is not None:
        tasks = comm.bcast(tasks, root=0)
        grph = comm.bcast(grph, root=0)

    ntask = len(tasks)

    # Get the options for this step.

    options = opts[step]

    # Now every process has the full list of tasks.  If we have multiple
    # processes for each task, split the communicator: comm_group joins the
    # processes working on the same task, comm_rank joins the processes
    # that have the same rank within their group.

    comm_group = comm
    comm_rank = None
    group = rank
    ngroup = nproc
    group_rank = 0
    if comm is not None:
        if taskproc > 1:
            ngroup = int(nproc / taskproc)
            group = int(rank / taskproc)
            group_rank = rank % taskproc
            comm_group = comm.Split(color=group, key=group_rank)
            comm_rank = comm.Split(color=group_rank, key=group)
        else:
            comm_group = None
            comm_rank = comm

    # Now we divide up the tasks among the groups of processes as
    # equally as possible.

    group_ntask = 0
    group_firsttask = 0

    if group < ngroup:
        # only assign tasks to whole groups
        if ntask < ngroup:
            # Fewer tasks than groups: the first ntask groups get one each.
            if group < ntask:
                group_ntask = 1
                group_firsttask = group
        elif step == 'zfind':
            # We load balance the bricks across process groups based
            # on the number of targets per brick.  All bricks with
            # < taskproc targets are weighted the same.
            #
            # The work sizes are computed before choosing the distribution
            # because the debug summary below needs them on both paths;
            # previously they were unbound when ntask == ngroup.
            bricksizes = [grph[x]['ntarget'] for x in tasks]
            worksizes = [
                taskproc if (x < taskproc) else x for x in bricksizes
            ]

            if ntask <= ngroup:
                # distribute uniform in this case
                group_firsttask, group_ntask = dist_uniform(
                    ntask, ngroup, group)
            else:
                if rank == 0:
                    log.debug("zfind {} groups".format(ngroup))
                    workstr = "".join("{} ".format(w) for w in worksizes)
                    log.debug("zfind work sizes = {}".format(workstr))

                group_firsttask, group_ntask = dist_discrete(
                    worksizes, ngroup, group)

            if group_rank == 0:
                worksum = np.sum(
                    worksizes[group_firsttask:group_firsttask + group_ntask])
                log.debug("group {} has tasks {}-{} sum = {}".format(
                    group, group_firsttask,
                    group_firsttask + group_ntask - 1, worksum))
        else:
            group_firsttask, group_ntask = dist_uniform(
                ntask, ngroup, group)

    # every group goes and does its tasks...

    faildir = os.path.join(proddir, 'run', 'failed')
    logdir = os.path.join(proddir, 'run', 'logs')

    # Number of skipped or failed tasks; only counted on the first process
    # of each group and shared with the rest of the group afterwards.
    group_failcount = 0

    if group_ntask > 0:
        for t in range(group_firsttask, group_firsttask + group_ntask):
            (night, gname) = graph_name_split(tasks[t])

            # Check if all inputs exist.  Only the first process of the
            # group touches the filesystem; the result is broadcast so the
            # whole group takes the same branch.

            missing = 0
            if group_rank == 0:
                for iname in grph[tasks[t]]['in']:
                    ind = grph[iname]
                    fspath = graph_path(rawdir, proddir, iname, ind['type'])
                    if not os.path.exists(fspath):
                        missing += 1
                        log.error(
                            "skipping step {} task {} due to missing input {}".
                            format(step, tasks[t], fspath))

            if comm_group is not None:
                missing = comm_group.bcast(missing, root=0)

            if missing > 0:
                if group_rank == 0:
                    group_failcount += 1
                continue

            nfaildir = os.path.join(faildir, night)
            nlogdir = os.path.join(logdir, night)

            # slice out just the graph for this task
            tgraph = graph_slice(grph, names=[tasks[t]], deps=True)
            ffile = os.path.join(nfaildir, "{}_{}.yaml".format(step, tasks[t]))

            # For this task, we will temporarily redirect stdout and stderr
            # to a task-specific log file.  Any log left over from an
            # earlier run is removed first, and the group waits for that
            # removal before anybody starts redirecting.

            tasklog = os.path.join(nlogdir, "{}.log".format(gname))
            if group_rank == 0:
                if os.path.isfile(tasklog):
                    os.remove(tasklog)
            if comm_group is not None:
                comm_group.barrier()

            with stdouterr_redirected(to=tasklog, comm=comm_group):
                try:
                    # if the step previously failed, clear that file now
                    if group_rank == 0:
                        if os.path.isfile(ffile):
                            os.remove(ffile)

                    log.debug(
                        "running step {} task {} (group {}/{} with {} processes)"
                        .format(step, tasks[t], (group + 1), ngroup, taskproc))

                    # All processes in comm_group will either return from this or ALL will
                    # raise an exception
                    run_task(step,
                             rawdir,
                             proddir,
                             tgraph,
                             options,
                             comm=comm_group)

                    # mark step as done in our group's graph
                    graph_mark(grph, tasks[t], state='done', descend=False)

                except:
                    # NOTE(review): the bare except also traps SystemExit and
                    # KeyboardInterrupt; presumably deliberate so that any
                    # abort inside run_task() is recorded as a failed task --
                    # confirm before narrowing it.
                    #
                    # The task threw an exception.  We want to dump all information
                    # that will be needed to re-run the run_task() function on just
                    # this task.
                    if group_rank == 0:
                        group_failcount += 1
                        msg = "FAILED: step {} task {} (group {}/{} with {} processes)".format(
                            step, tasks[t], (group + 1), ngroup, taskproc)
                        log.error(msg)
                        log.error(traceback.format_exc())
                        fyml = {
                            'step': step,
                            'rawdir': rawdir,
                            'proddir': proddir,
                            'task': tasks[t],
                            'graph': tgraph,
                            'opts': options,
                            'procs': taskproc,
                        }
                        if not os.path.isfile(ffile):
                            log.error('Dumping yaml graph to ' + ffile)
                            # we are the first process to hit this
                            with open(ffile, 'w') as f:
                                yaml.dump(fyml, f, default_flow_style=False)
                    # mark the step as failed in our group's local graph
                    graph_mark(grph, tasks[t], state='fail', descend=True)

        if comm_group is not None:
            group_failcount = comm_group.bcast(group_failcount, root=0)

    # Now we take the graphs from all groups and merge their states.  The
    # first process of every group takes part in the merge and in the
    # reduction of the failure count, then shares both with its group.

    failcount = group_failcount

    if comm is not None:
        if group_rank == 0:
            graph_merge_state(grph, comm=comm_rank)
            failcount = comm_rank.allreduce(failcount)
        if comm_group is not None:
            grph = comm_group.bcast(grph, root=0)
            failcount = comm_group.bcast(failcount, root=0)

    return grph, ntask, failcount