Example #1
0
def do_mem(logdir, network, batch_size):
    """Measure memory allocations of a workload.

    Runs ``network`` at ``batch_size`` either directly on TF or, when
    ``FLAGS.use_salus`` is set, on a Salus server configured with the
    ``AllocProf`` preset and ``memop`` logging. Results are written (via
    ``atomic_directory``) to ``logdir / ("salus" | "tf") / <canonical name>``.

    On TF, the raw log produced by ``run_tf`` is filtered down to the lines
    matching ``] +`` / ``] -`` and saved as ``alloc.output``; the raw log is
    then deleted. On Salus, the client output goes to ``rpc.output``.

    Returns:
        The final output directory for this measurement.
    """
    # The speech workload is measured over fewer iterations than the rest.
    batch_num = 5 if network == "speech" else 20

    logger.info(f'Measuring memory for {network}_{batch_size} for {batch_num} iter')

    ex = "salus" if FLAGS.use_salus else "tf"
    final_dst = logdir / ex / WTL.from_name(network).canonical_name(RunConfig(batch_size, batch_num, None))
    with atomic_directory(final_dst) as outputdir:
        if not FLAGS.use_salus:
            logger.info('    Running on TF')
            wl = WTL.create(network, batch_size, batch_num, Executor.TF)
            # Turn on TF's verbose C++ logging; presumably this is what makes
            # the allocation lines matched below appear in the log (confirm).
            wl.env['TF_CPP_MIN_VLOG_LEVEL'] = '1'
            wl.env['TF_CPP_MIN_LOG_LEVEL'] = ''
            run_tf(outputdir, wl)
            # Filter the raw log down to allocation lines and move it to a more
            # convenient name. Only the first entry returned by iterdir() is
            # processed (iteration order is arbitrary), so this assumes run_tf
            # leaves exactly one file in outputdir -- TODO confirm.
            for f in pathlib.Path(outputdir).iterdir():
                with f.with_name('alloc.output').open('w') as file:
                    # `grep -E` rather than the obsolescent `egrep` alias.
                    grep = execute(['grep', '-E', r"] (\+|-)", f.name], stdout=file, cwd=str(f.parent))
                    grep.wait()
                f.unlink()
                break
            else:
                # Previously this case was silently ignored.
                logger.warning(f'    No TF output found in {outputdir}; alloc.output not created')
        else:
            scfg = maybe_forced_preset(presets.AllocProf)
            scfg.logconf = "memop"
            scfg.output_dir = outputdir
            server = SalusServer(scfg)
            with server.run():
                logger.info('    Running on Salus')
                WTL.block_run(network, batch_size, batch_num, Executor.Salus, outputdir / 'rpc.output')

    return final_dst
Example #2
0
def do_jct(logdir, network, batch_size):
    """Measure basic job completion time (JCT) of a workload.

    The workload always runs on TF. It also runs on TFDist when
    ``FLAGS.do_tfdist`` is set. Unless ``FLAGS.is_mps`` is set, it then runs
    on Salus after a 20-iteration warm-up. With ``FLAGS.is_mps`` the TF and
    TFDist output files get a ``-mps`` suffix.

    When ``FLAGS.resume`` is set, any run whose output file already exists in
    the final directory is skipped.

    Returns:
        The final output directory (``logdir / <canonical workload name>``).
    """
    iterations = 20

    workload = WTL.from_name(network)
    final_dst = logdir / workload.canonical_name(RunConfig(batch_size, iterations, None))

    with atomic_directory(final_dst) as outputdir:
        logger.info(f'Measuring basic JCT for {iterations} iterations')

        suffix = '-mps' if FLAGS.is_mps else ''
        gpu_file = f'gpu{suffix}.output'
        tfdist_file = f'tfdist{suffix}.output'

        def needs_run(filename):
            # Run unless we are resuming and this result is already present.
            return not (final_dst / filename).exists() or not FLAGS.resume

        if needs_run(gpu_file):
            logger.info('    Running on TF')
            WTL.block_run(network, batch_size, iterations, Executor.TF, outputdir / gpu_file)

        if FLAGS.do_tfdist and needs_run(tfdist_file):
            with TFDistServer().run():
                logger.info('    Running on TFDist')
                WTL.block_run(network, batch_size, iterations, Executor.TFDist, outputdir / tfdist_file)

        # Salus is not measured under MPS.
        if FLAGS.is_mps:
            logger.info('    Skipping Salus jct when MPS is on')
            return final_dst

        if needs_run('rpc.output'):
            salus_cfg = maybe_forced_preset(presets.MostEfficient)
            salus_cfg.output_dir = outputdir
            salus = SalusServer(salus_cfg)
            with salus.run():
                logger.info('    Warming up Salus')
                # The warm-up is fixed at 20 iterations regardless of the measured run.
                WTL.block_run(network, batch_size, 20, Executor.Salus, outputdir / 'rpc-warm.output')

                logger.info('    Running on Salus')
                WTL.block_run(network, batch_size, iterations, Executor.Salus, outputdir / 'rpc.output')

    return final_dst
Example #3
0
def do_jct_hint(logdir, network, batch_size, per_iter, target, tag):
    """Calculate JCT for a target run time.

    Searches for a ``batch_num`` whose TF run time lands within
    ``target * FLAGS.threshold`` of ``target``, trying at most
    ``FLAGS.max_chance`` times. Run time is assumed to scale linearly with
    ``batch_num``. The chosen ``batch_num`` is then run on Salus, after a
    20-iteration warm-up.

    Args:
        logdir: base output directory; results go to
            ``logdir / f"{network}_{batch_size}_{tag}"``.
        network: workload name.
        batch_size: batch size for every run.
        per_iter: initial estimate of time per iteration. It must be in the
            same unit as ``target`` and non-zero.
        target: desired total run time.
        tag: label used in the output directory name and in log messages.

    Returns:
        The per-iteration time. On resume this is parsed from the existing
        ``rpc.output``; otherwise it is the last TF-measured estimate.
    """
    final_dst = logdir / f"{network}_{batch_size}_{tag}"
    with atomic_directory(final_dst) as outputdir:
        if (final_dst / 'rpc.output').exists() and FLAGS.resume:
            # NOTE(review): on resume this returns the Salus-side average from
            # rpc.output, whereas a fresh run returns the TF-measured estimate;
            # confirm callers treat the two as interchangeable.
            per_iter = parse_output_float(final_dst / 'rpc.output', r'^Average excluding[^0-9.]+([0-9.]+).*')
            return per_iter

        logger.info(f"Finding suitable batch_num for {tag}")
        actual = 0
        chance = FLAGS.max_chance
        # Computed up front so batch_num is defined even when max_chance <= 0.
        # Clamped to at least one iteration: a zero batch_num would run an
        # empty workload and then divide by zero when updating per_iter.
        batch_num = max(1, int(target / per_iter))
        while chance > 0 and abs(actual - target) >= target * FLAGS.threshold:
            chance -= 1
            batch_num = max(1, int(target / per_iter))

            logger.info(f'    Trying batch_num={batch_num}')
            file = outputdir / 'gpu.output'
            WTL.block_run(network, batch_size, batch_num, Executor.TF, file)

            actual = parse_output_float(file, r'^JCT[^0-9.]+([0-9.]+).*')
            # assume linear time distribution
            per_iter = actual / batch_num
            logger.info(f"   actual_time={actual}, per_iter={per_iter}")

        if abs(actual - target) >= target * FLAGS.threshold:
            # Search ran out of chances; proceed with the last candidate.
            logger.warning(f"    Could not reach target={target} within {FLAGS.max_chance} tries "
                           f"(actual_time={actual}); continuing with batch_num={batch_num}")

        # use the batch_num to run salus
        logger.info(f"Using batch_num={batch_num} for {tag}")
        scfg = maybe_forced_preset(presets.MostEfficient)
        scfg.output_dir = outputdir
        server = SalusServer(scfg)
        with server.run():
            logger.info('    Warming up Salus')
            WTL.block_run(network, batch_size, 20, Executor.Salus, outputdir / 'rpc-warm.output')

            logger.info('    Running on Salus')
            WTL.block_run(network, batch_size, batch_num, Executor.Salus, outputdir / 'rpc.output')
    return per_iter