def run_tfdist(output, *actions):
    # type: (Path, *TAction) -> List[Workload]
    """Run a sequence of actions"""
    workloads = []  # type: List[Workload]

    try:
        with atomic_directory(output) as temp_dir:  # type: Path
            # start server
            ss = TFDistServer(outputdir=temp_dir)
            with ss.run():
                # run each action specified in the sequence
                for act in actions:
                    ss.check()

                    if isinstance(act, Workload):
                        if act.executor != Executor.TFDist:
                            raise ValueError('run_tfdist can only run TFDist workloads')
                        output_file = temp_dir / f'{act.output_name}.{act.batch_num}iter.{len(workloads)}.output'
                        act.run(output_file)
                        workloads.append(act)
                    elif isinstance(act, (Pause, RunFn)):
                        act.run(workloads, temp_dir=temp_dir)
                    else:
                        raise ValueError(f"Unexpected value `{act}' of {type(act)} passed to run_tfdist")

                logger.info('Waiting for all workloads to finish')
                ss.wait_workloads(workloads)
    finally:
        # if any workload is still alive, we are doing cleanup
        for w in workloads:
            if w.proc is not None and w.proc.poll() is None:
                logger.warning(f'Killing workload that is not stopped yet: {w.canonical_name}')
                kill_tree(w.proc, hard=True)

        # check each workload and fix its output_file path
        for w in workloads:
            if not FLAGS.ignore_error and w.proc.returncode != 0:
                prompt.pause()
                raise RuntimeError(f'Workload {w.canonical_name} did not finish cleanly: {w.proc.returncode}')
            w.output_file = output

    return workloads
def main(argv):
    # type: (Sequence[str]) -> None
    """Replay a workload trace, admitting jobs according to the FIFO/Salus/TF packing policy."""
    scfg = maybe_forced_preset(presets.MostEfficient)
    scfg.scheduler = 'pack'

    ex = Executor.Salus if FLAGS.use_salus else Executor.TF
    if FLAGS.fifo:
        logdir = FLAGS.save_dir / 'fifo'
    else:
        logdir = FLAGS.save_dir / ex.value

    # create workload instances
    workloads = load_trace(argv[0], ex)

    # check that workloads have the info we need
    if ex == Executor.TF and not FLAGS.fifo:
        for w, _, _ in workloads:
            for field in ['peakmem']:
                if find_geometry(w, field) is None:
                    raise ValueError(f'Missing {field} data for workload {w.canonical_name} of {w.batch_num} iters'
                                     f', available geometries: {w.wtl._geometries}')

    # enable overcommit
    if FLAGS.overcommit > 1:
        for w, _, _ in workloads:
            w.env['TF_GPU_ALLOCATOR'] = 'cuda_managed'

    def accept_workload(w, alive):
        # admission policy: FIFO runs one workload at a time, Salus caps concurrency,
        # and plain TF packs by peak memory under the overcommit limit
        if FLAGS.fifo:
            return len(alive) == 0
        elif FLAGS.use_salus:
            return len(alive) < FLAGS.concurrent
        else:
            currmem = sum(wl.geometry.peakmem for wl in alive)
            return w.geometry.peakmem + currmem < FLAGS.overcommit * FLAGS.phymem

    try:
        try:
            with atomic_directory(logdir) as tmp:
                # copy trace file
                shutil.copy2(argv[0], str(tmp / 'trace.csv'))

                with (tmp / 'exp15.output').open('w') as f:
                    started = []
                    pending = []
                    alive = []

                    def workload_done(proc):
                        w = proc.workload
                        logger.info(f'Finished workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                        print(f'{datetime.now()}: Finished workload '
                              f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                              file=f)

                    def do_stuff(rel_time):
                        if workloads:
                            w, submit_time, row = workloads[0]
                            if rel_time >= submit_time:
                                workloads.pop(0)
                                w.job_id = row["job_id"]
                                logger.info(f'Queued workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                                print(f'{datetime.now()}: Queued workload '
                                      f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                                      file=f)
                                pending.append(w)

                        _, alive[:] = SalusServer.wait_workloads(alive, timeout=0, callback=workload_done)

                        while pending and accept_workload(pending[0], alive):
                            w = pending.pop(0)
                            logger.info(f'Started workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                            print(f'{datetime.now()}: Started workload '
                                  f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                                  file=f)
                            output_file = tmp / f'{w.output_name}.{w.batch_num}iter.{w.job_id}.output'
                            w.run(output_file)
                            started.append(w)
                            alive.append(w)

                        _, alive[:] = SalusServer.wait_workloads(alive, timeout=0, callback=workload_done)

                        if not workloads and not pending:
                            _, alive[:] = SalusServer.wait_workloads(alive, callback=workload_done)
                            return False

                        return True

                    def event_loop():
                        # check every 0.1 second
                        interval = 0.1
                        origin = default_timer()
                        while True:
                            st = default_timer()
                            should_continue = do_stuff(st - origin)
                            if not should_continue:
                                break
                            ed = default_timer()
                            elapsed = ed - st
                            time.sleep(interval - (elapsed % interval))

                    if FLAGS.use_salus:
                        ss = SalusServer(scfg.copy(output_dir=logdir))
                        with ss.run():
                            event_loop()
                    else:
                        event_loop()
        except Exception:
            logger.exception("Got exception when running workloads")
    finally:
        # if any workload is still alive, we are doing cleanup
        for w, _, _ in workloads:
            if w.proc is not None and w.proc.poll() is None:
                logger.warning(f'Killing workload that is not stopped yet: {w.canonical_name}')
                kill_tree(w.proc, hard=True)

        # check each workload and fix its output_file path
        for w, _, _ in workloads:
            if not FLAGS.ignore_error and w.proc is not None and w.proc.returncode != 0:
                prompt.pause()
                raise RuntimeError(f'Workload {w.canonical_name} did not finish cleanly: {w.proc.returncode}')
            if w.output_file is not None:
                w.output_file = logdir / w.output_file.name