def do_mem(logdir, network, batch_size):
    """Measure memory allocation behavior of a workload"""
    batch_num = 20
    if network == "speech":
        batch_num = 5

    logger.info(f'Measuring memory for {network}_{batch_size} for {batch_num} iter')
    ex = "salus" if FLAGS.use_salus else "tf"
    final_dst = logdir / ex / WTL.from_name(network).canonical_name(RunConfig(batch_size, batch_num, None))
    with atomic_directory(final_dst) as outputdir:
        if not FLAGS.use_salus:
            logger.info(' Running on TF')
            wl = WTL.create(network, batch_size, batch_num, Executor.TF)
            wl.env['TF_CPP_MIN_VLOG_LEVEL'] = '1'
            wl.env['TF_CPP_MIN_LOG_LEVEL'] = ''
            run_tf(outputdir, wl)

            # filter and move the log to a more convenient name,
            # keeping only the allocation (+) / deallocation (-) lines
            for f in pathlib.Path(outputdir).iterdir():
                with f.with_name('alloc.output').open('w') as file:
                    grep = execute(['egrep', r"] (\+|-)", f.name], stdout=file, cwd=str(f.parent))
                    grep.wait()
                f.unlink()
                break
        else:
            scfg = maybe_forced_preset(presets.AllocProf)
            scfg.logconf = "memop"
            scfg.output_dir = outputdir

            server = SalusServer(scfg)
            with server.run():
                logger.info(' Running on Salus')
                WTL.block_run(network, batch_size, batch_num, Executor.Salus, outputdir / 'rpc.output')

    return final_dst
def limit_concurrent(wls):
    # type: (Iterable[Workload]) -> None
    """Block until fewer than FLAGS.concurrent_jobs workloads are still running"""
    gone, alive = SalusServer.wait_workloads(wls, timeout=0)
    while len(alive) >= FLAGS.concurrent_jobs:
        gone, alive = SalusServer.wait_workloads(wls, timeout=0)
        time.sleep(.25)
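# A minimal usage sketch (not part of the original script): how limit_concurrent
# is meant to be called when launching one workload at a time. `make_workloads`
# and `output_dir` are hypothetical stand-ins for whatever produces the Workload
# objects and the output location.
def _example_limit_concurrent(make_workloads, output_dir):
    running = []
    for w in make_workloads():
        limit_concurrent(running)   # blocks while FLAGS.concurrent_jobs workloads are still alive
        w.run(output_dir / f'{w.output_name}.output')
        running.append(w)
    # drain the tail: wait for everything still alive to exit
    SalusServer.wait_workloads(running)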
def do_jct(logdir, network, batch_size):
    """Do basic JCT on workload"""
    batch_num = 20

    final_dst = logdir / WTL.from_name(network).canonical_name(RunConfig(batch_size, batch_num, None))
    with atomic_directory(final_dst) as outputdir:
        logger.info(f'Measuring basic JCT for {batch_num} iterations')
        mps_name = '-mps' if FLAGS.is_mps else ''
        if not (final_dst / 'gpu{}.output'.format(mps_name)).exists() or not FLAGS.resume:
            logger.info(' Running on TF')
            WTL.block_run(network, batch_size, batch_num, Executor.TF,
                          outputdir / 'gpu{}.output'.format(mps_name))

        if FLAGS.do_tfdist:
            if not (final_dst / 'tfdist{}.output'.format(mps_name)).exists() or not FLAGS.resume:
                with TFDistServer().run():
                    logger.info(' Running on TFDist')
                    WTL.block_run(network, batch_size, batch_num, Executor.TFDist,
                                  outputdir / 'tfdist{}.output'.format(mps_name))

        if FLAGS.is_mps:
            logger.info(' Skipping Salus jct when MPS is on')
            return final_dst

        if not (final_dst / 'rpc.output').exists() or not FLAGS.resume:
            scfg = maybe_forced_preset(presets.MostEfficient)
            scfg.output_dir = outputdir

            server = SalusServer(scfg)
            with server.run():
                logger.info(' Warming up Salus')
                # always use 20 batch num when warming up
                WTL.block_run(network, batch_size, 20, Executor.Salus, outputdir / 'rpc-warm.output')

                logger.info(' Running on Salus')
                WTL.block_run(network, batch_size, batch_num, Executor.Salus, outputdir / 'rpc.output')

    return final_dst
def run(self, workloads, **kwargs):
    if self == Pause.Manual:
        prompt.pause()
    elif self == Pause.Wait:
        logger.info(f"Waiting for the current {len(workloads)} workloads to finish")
        SalusServer.wait_workloads(workloads)
    else:
        logger.info(f"Sleep {self} seconds")
        time.sleep(self)
def run_tf(output_dir, *actions):
    # type: (Path, *TAction) -> List[Workload]
    """Run a sequence of actions"""
    workloads = []  # type: List[Workload]

    try:
        with atomic_directory(output_dir) as temp_dir:  # type: Path
            # Do action specified in seq
            for act in actions:
                if isinstance(act, Workload):
                    if act.executor != Executor.TF:
                        raise ValueError('run_tf can only run TF workloads')
                    output_file = temp_dir / f'{act.output_name}.{act.batch_num}iter.{len(workloads)}.output'
                    act.run(output_file)
                    workloads.append(act)
                elif isinstance(act, (Pause, RunFn)):
                    act.run(workloads, temp_dir=temp_dir)
                else:
                    raise ValueError(f"Unexpected value `{act}' of {type(act)} passed to run_tf")

            logger.info('Waiting for all workloads to finish')
            SalusServer.wait_workloads(workloads)
    except Exception:
        logger.exception("Got exception when running workloads")
    finally:
        # clean up any workload that is still alive
        for w in workloads:
            if w.proc is not None and w.proc.poll() is None:
                logger.warning(f'Killing workload that is not stopped yet: {w.canonical_name}')
                kill_tree(w.proc, hard=True)

        # check each workload and fix its output_file path
        for w in workloads:
            if not FLAGS.ignore_error and w.proc.returncode != 0:
                prompt.pause()
                raise RuntimeError(f'Workload {w.canonical_name} did not finish cleanly: {w.proc.returncode}')
            w.output_file = output_dir / w.output_file.name

    return workloads
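# A hedged usage sketch (not in the original): run_tf takes an output directory
# followed by a mix of Workload objects and Pause/RunFn actions. The network
# names, batch sizes and the Pause(30) construction are illustrative assumptions
# only; Pause.run above sleeps for `self` seconds in its numeric branch, which
# suggests a numeric Pause is valid.
def _example_run_tf():
    return run_tf(
        FLAGS.save_dir / 'seq',
        WTL.create('alexnet', 25, 100, Executor.TF),
        Pause(30),   # assumed: pause 30 seconds between the two jobs
        WTL.create('resnet50', 50, 100, Executor.TF),
    )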
def do_jct_hint(logdir, network, batch_size, per_iter, target, tag):
    """Search for a batch_num whose JCT is close to `target` seconds, then run it on Salus"""
    final_dst = logdir / f"{network}_{batch_size}_{tag}"
    with atomic_directory(final_dst) as outputdir:
        if (final_dst / 'rpc.output').exists() and FLAGS.resume:
            per_iter = parse_output_float(final_dst / 'rpc.output', r'^Average excluding[^0-9.]+([0-9.]+).*')
            return per_iter

        logger.info(f"Finding suitable batch_num for {tag}")
        actual = 0
        chance = FLAGS.max_chance
        batch_num = int(target / per_iter)
        while chance > 0 and abs(actual - target) >= target * FLAGS.threshold:
            chance -= 1
            batch_num = int(target / per_iter)
            logger.info(f' Trying batch_num={batch_num}')

            file = outputdir / 'gpu.output'
            WTL.block_run(network, batch_size, batch_num, Executor.TF, file)
            actual = parse_output_float(file, r'^JCT[^0-9.]+([0-9.]+).*')
            # assume per-iteration time is constant, so JCT scales linearly with batch_num
            per_iter = actual / batch_num
            logger.info(f" actual_time={actual}, per_iter={per_iter}")

        # use the batch_num to run salus
        logger.info(f"Using batch_num={batch_num} for {tag}")
        scfg = maybe_forced_preset(presets.MostEfficient)
        scfg.output_dir = outputdir

        server = SalusServer(scfg)
        with server.run():
            logger.info(' Warming up Salus')
            WTL.block_run(network, batch_size, 20, Executor.Salus, outputdir / 'rpc-warm.output')

            logger.info(' Running on Salus')
            WTL.block_run(network, batch_size, batch_num, Executor.Salus, outputdir / 'rpc.output')

    return per_iter
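# Worked example of the batch_num search above (illustrative numbers only): with
# per_iter = 0.12 s and target = 60 s, the first guess is
# batch_num = int(60 / 0.12) = 500. If that TF run reports JCT = 66 s, per_iter is
# re-estimated as 66 / 500 = 0.132 s and the next guess is int(60 / 0.132) = 454.
# The loop stops once |actual - target| < target * FLAGS.threshold or the
# FLAGS.max_chance retries are used up.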
def main(argv):
    # type: (Sequence[str]) -> None
    scfg = maybe_forced_preset(presets.MostEfficient)
    scfg.scheduler = 'pack'

    ex = Executor.Salus if FLAGS.use_salus else Executor.TF
    if FLAGS.fifo:
        logdir = FLAGS.save_dir / 'fifo'
    else:
        logdir = FLAGS.save_dir / ex.value

    # create workload instances
    workloads = load_trace(argv[0], ex)

    # Check and update if workloads have the info we need
    if ex == Executor.TF and not FLAGS.fifo:
        for w, _, _ in workloads:
            for field in ['peakmem']:
                if find_geometry(w, field) is None:
                    raise ValueError(f'Missing {field} data for workload {w.canonical_name} of {w.batch_num} iters'
                                     f', available geometries: {w.wtl._geometries}')

    # enable overcommit
    if FLAGS.overcommit > 1:
        for w, _, _ in workloads:
            w.env['TF_GPU_ALLOCATOR'] = 'cuda_managed'

    def accept_workload(w, alive):
        if FLAGS.fifo:
            return len(alive) == 0
        elif FLAGS.use_salus:
            return len(alive) < FLAGS.concurrent
        else:
            currmem = sum(wl.geometry.peakmem for wl in alive)
            return w.geometry.peakmem + currmem < FLAGS.overcommit * FLAGS.phymem

    try:
        try:
            with atomic_directory(logdir) as tmp:
                # copy trace file
                shutil.copy2(argv[0], str(tmp / 'trace.csv'))

                with (tmp / 'exp15.output').open('w') as f:
                    started = []
                    pending = []
                    alive = []

                    def workload_done(proc):
                        w = proc.workload
                        logger.info(f'Finished workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                        print(f'{datetime.now()}: Finished workload '
                              f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                              file=f)

                    def do_stuff(rel_time):
                        if workloads:
                            w, submit_time, row = workloads[0]
                            if rel_time >= submit_time:
                                workloads.pop(0)
                                w.job_id = row["job_id"]
                                logger.info(f'Queued workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                                print(f'{datetime.now()}: Queued workload '
                                      f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                                      file=f)
                                pending.append(w)

                        # reap finished workloads before admitting new ones
                        _, alive[:] = SalusServer.wait_workloads(alive, timeout=0, callback=workload_done)

                        while pending and accept_workload(pending[0], alive):
                            w = pending.pop(0)
                            logger.info(f'Started workload {w.output_name}.{w.batch_num}iter.{w.job_id}')
                            print(f'{datetime.now()}: Started workload '
                                  f'{w.output_name}.{w.batch_num}iter.{w.job_id}',
                                  file=f)

                            output_file = tmp / f'{w.output_name}.{w.batch_num}iter.{w.job_id}.output'
                            w.run(output_file)
                            started.append(w)
                            alive.append(w)

                            _, alive[:] = SalusServer.wait_workloads(alive, timeout=0, callback=workload_done)

                        if not workloads and not pending:
                            # nothing left to submit: block until everything exits
                            _, alive[:] = SalusServer.wait_workloads(alive, callback=workload_done)
                            return False

                        return True

                    def event_loop():
                        # check every 0.1 second
                        interval = 0.1
                        origin = default_timer()
                        while True:
                            st = default_timer()
                            should_continue = do_stuff(st - origin)
                            if not should_continue:
                                break
                            ed = default_timer()

                            elapsed = ed - st
                            time.sleep(interval - (elapsed % interval))

                    if FLAGS.use_salus:
                        ss = SalusServer(scfg.copy(output_dir=logdir))
                        with ss.run():
                            event_loop()
                    else:
                        event_loop()
        except Exception:
            logger.exception("Got exception when running workloads")
    finally:
        # clean up any workload that is still alive
        for w, _, _ in workloads:
            if w.proc is not None and w.proc.poll() is None:
                logger.warning(f'Killing workload that is not stopped yet: {w.canonical_name}')
                kill_tree(w.proc, hard=True)

        # check each workload and fix its output_file path
        for w, _, _ in workloads:
            if not FLAGS.ignore_error and w.proc is not None and w.proc.returncode != 0:
                prompt.pause()
                raise RuntimeError(f'Workload {w.canonical_name} did not finish cleanly: {w.proc.returncode}')
            if w.output_file is not None:
                w.output_file = logdir / w.output_file.name
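# Note on accept_workload (added, not in the original script): with plain TF and
# no FIFO, admission is by packed peak memory, i.e. a pending workload starts only
# while sum(peakmem of alive workloads) + its own peakmem stays below
# FLAGS.overcommit * FLAGS.phymem. For example, with overcommit = 1.5 the trace
# may be packed up to 1.5x physical GPU memory, which is presumably why
# TF_GPU_ALLOCATOR is switched to 'cuda_managed' whenever FLAGS.overcommit > 1.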