コード例 #1
0
def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage.

    Args:
        model_name: Name of the model whose golden stats are requested.
            Only ``"lm"`` is recognized by this function.
        args: Accepted for backward compatibility with existing callers;
            it is not read by this function.

    Returns:
        The value of ``lm_wikitext2.get_golden_real_stats(False)`` when
        ``model_name == "lm"``.

    Raises:
        RuntimeError: If ``model_name`` is not recognized.
    """
    if model_name == "lm":
        return lm_wikitext2.get_golden_real_stats(False)
    # Report the value that was actually checked. Reading ``args.model_name``
    # here could disagree with ``model_name``, or raise AttributeError
    # instead of the intended RuntimeError when ``args`` is None.
    raise RuntimeError(f"Unrecognized model_name: {model_name!r}")
コード例 #2
0
def get_model_specs(model_name):
    """Return a dict with configurations required for configuring `model_name` model.

    Args:
        model_name: Either ``"lm"`` or ``"seq"``.

    Returns:
        The value of ``get_model_config()`` from the module matching
        ``model_name`` (``lm_wikitext2`` or ``offload_seq``).

    Raises:
        RuntimeError: If ``model_name`` is not recognized.
    """
    if model_name == "lm":
        return lm_wikitext2.get_model_config()
    if model_name == "seq":
        return offload_seq.get_model_config()
    # The previous ``"..." % args.model_name`` had no format specifier
    # (TypeError) and referenced an undefined ``args`` (NameError), so the
    # intended RuntimeError could never be raised.
    raise RuntimeError(f"Unrecognized model_name: {model_name!r}")
コード例 #3
0
def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model.

    Args:
        model_name: Either ``"lm"`` or ``"seq"``.

    Returns:
        The value of ``get_benchmark_config()`` from the module matching
        ``model_name`` (``lm_wikitext2`` or ``offload_seq``).

    Raises:
        RuntimeError: If ``model_name`` is not recognized.
    """
    # Dispatch on the parameter, not on a global ``args`` object: the
    # original ignored ``model_name`` entirely and failed with NameError
    # when no module-level ``args`` existed.
    if model_name == "lm":
        return lm_wikitext2.get_benchmark_config()
    if model_name == "seq":
        return offload_seq.get_benchmark_config()
    raise RuntimeError(f"Unrecognized model_name: {model_name!r}")