Example #1
def run_tfdist(output, *actions):
    # type: (Path, *TAction) -> List[Workload]
    """Run a sequence of actions"""
    workloads = []  # type: List[Workload]

    try:
        with atomic_directory(output) as temp_dir:  # type: Path
            # start server
            ss = TFDistServer(outputdir=temp_dir)
            with ss.run():
                # Run each action in the given sequence
                for act in actions:
                    ss.check()

                    if isinstance(act, Workload):
                        if act.executor != Executor.TFDist:
                            raise ValueError('run_tfdist can only run TFDist workloads')
                        output_file = temp_dir / f'{act.output_name}.{act.batch_num}iter.{len(workloads)}.output'

                        act.run(output_file)
                        workloads.append(act)
                    elif isinstance(act, (Pause, RunFn)):
                        act.run(workloads, temp_dir=temp_dir)
                    else:
                        raise ValueError(f"Unexpected value `{act}' of {type(act)} passed to run_tfdist")

                logger.info('Waiting for all workloads to finish')
                ss.wait_workloads(workloads)
    finally:
        # if any workloads are still alive at this point, we are cleaning up after a failure
        for w in workloads:
            if w.proc is not None and w.proc.poll() is None:
                logger.warning(f'Killing workload that is not stopped yet: {w.canonical_name}')
                kill_tree(w.proc, hard=True)

        # check each workload and fix its output_file path
        for w in workloads:
            if not FLAGS.ignore_error and w.proc.returncode != 0:
                prompt.pause()
                raise RuntimeError(f'Workload {w.canonical_name} did not finish cleanly: {w.proc.returncode}')
            w.output_file = output
    return workloads
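
A minimal sketch of how run_tfdist might be invoked. The WTL.create factory, its argument order, and the Pause duration are assumptions for illustration, not taken from the project:

# Hypothetical usage sketch; WTL.create and its arguments are assumed, not from the source.
from pathlib import Path

actions = [
    WTL.create('resnet50', 25, 100, executor=Executor.TFDist),  # assumed: (name, batch_size, batch_num)
    Pause(10),                                                   # assumed: pause for 10 seconds
    WTL.create('vgg11', 25, 100, executor=Executor.TFDist),
]
workloads = run_tfdist(Path('/tmp/tfdist-run'), *actions)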
Example #2
File: exp15.py  Project: vycezhong/Salus-1
def main(argv):
    # type: (Sequence[str]) -> None
    scfg = maybe_forced_preset(presets.MostEfficient)
    scfg.scheduler = 'pack'

    ex = Executor.Salus if FLAGS.use_salus else Executor.TF
    if FLAGS.fifo:
        logdir = FLAGS.save_dir / 'fifo'
    else:
        logdir = FLAGS.save_dir / ex.value
    # create workload instances
    workloads = load_trace(argv[0], ex)

    # Check that the workloads carry the info we need
    if ex == Executor.TF and not FLAGS.fifo:
        for w, _, _ in workloads:
            for field in ['peakmem']:
                if find_geometry(w, field) is None:
                    raise ValueError(f'Missing {field} data for workload {w.canonical_name} of {w.batch_num} iters'
                                     f', available geometries: {w.wtl._geometries}')

    # enable overcommit
    if FLAGS.overcommit > 1:
        for w, _, _ in workloads:
            w.env['TF_GPU_ALLOCATOR'] = 'cuda_managed'

    def accept_workload(w, alive):
        if FLAGS.fifo:
            return len(alive) == 0
        elif FLAGS.use_salus:
            return len(alive) < FLAGS.concurrent
        else:
            currmem = sum(wl.geometry.peakmem for wl in alive)
            return w.geometry.peakmem + currmem < FLAGS.overcommit * FLAGS.phymem

    try:
        try:
            with atomic_directory(logdir) as tmp:
                # copy trace file
                shutil.copy2(argv[0], str(tmp/'trace.csv'))

                with (tmp / 'exp15.output').open('w') as f:
                    started = []
                    pending = []
                    alive = []

                    def workload_done(proc):
                        w = proc.workload
                        logger.info(f'Finished workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                        print(f'{datetime.now()}: Finished workload '
                              f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                              file=f)

                    def do_stuff(rel_time):
                        if workloads:
                            w, submit_time, row = workloads[0]
                            if rel_time >= submit_time:
                                workloads.pop(0)
                                w.job_id = row["job_id"]
                                logger.info(f'Queued workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                                print(f'{datetime.now()}: Queued workload '
                                      f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                                      file=f)
                                pending.append(w)

                        _, alive[:] = SalusServer.wait_workloads(alive, timeout=0, callback=workload_done)

                        while pending and accept_workload(pending[0], alive):
                            w = pending.pop(0)

                            logger.info(f'Started workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                            print(f'{datetime.now()}: Started workload '
                                  f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                                  file=f)

                            output_file = tmp / f'{w.output_name}.{w.batch_num}iter.{w.job_id}.output'
                            w.run(output_file)
                            started.append(w)
                            alive.append(w)

                            _, alive[:] = SalusServer.wait_workloads(alive, timeout=0, callback=workload_done)

                        if not workloads and not pending:
                            _, alive[:] = SalusServer.wait_workloads(alive, callback=workload_done)
                            return False

                        return True

                    def event_loop():
                        # check every 0.1 second
                        interval = 0.1
                        origin = default_timer()
                        while True:
                            st = default_timer()
                            should_continue = do_stuff(st - origin)
                            if not should_continue:
                                break
                            ed = default_timer()

                            elapsed = ed - st
                            time.sleep(interval - (elapsed % interval))

                    if FLAGS.use_salus:
                        ss = SalusServer(scfg.copy(output_dir=logdir))
                        with ss.run():
                            event_loop()
                    else:
                        event_loop()

        except Exception:
            logger.exception("Got exception when running workloads")
    finally:
        # if any workloads are still alive at this point, we are cleaning up after a failure
        for w, _, _ in workloads:
            if w.proc is not None and w.proc.poll() is None:
                logger.warning(f'Killing workload that is not stopped yet: {w.canonical_name}')
                kill_tree(w.proc, hard=True)

        # check each workload and fix its output_file path
        for w, _, _ in workloads:
            if not FLAGS.ignore_error and w.proc is not None and w.proc.returncode != 0:
                prompt.pause()
                raise RuntimeError(f'Workload {w.canonical_name} did not finish cleanly: {w.proc.returncode}')
            if w.output_file is not None:
                w.output_file = logdir / w.output_file.name
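
The event_loop above polls do_stuff on a fixed 0.1-second interval; sleeping for interval - (elapsed % interval) realigns each iteration to the next tick boundary even when do_stuff overruns a tick. A self-contained sketch of the same fixed-interval polling pattern (run_fixed_interval and tick are illustrative names, not from the project):

from timeit import default_timer
import time

def run_fixed_interval(tick, interval=0.1):
    # Call tick(rel_time) roughly every `interval` seconds until it returns False.
    origin = default_timer()
    while True:
        start = default_timer()
        if not tick(start - origin):
            break
        elapsed = default_timer() - start
        time.sleep(interval - (elapsed % interval))  # sleep up to the next tick boundary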