def test_compare(self): compare = benchmark_utils.Compare([ benchmark_utils.Timer( "torch.ones((n,))", globals={"n": n}, description="ones", label=str(n)).timeit(3) for n in range(3) ]) compare.print()
def main():
    """Benchmark two ``torch.add`` statements over a parameter sweep and print tables.

    One ``Timer`` is built for every combination of environment ("master",
    "my_branch", "severe_regression"), task, input size and thread count.
    Every timer is run ``repeats`` times with ``blocked_autorange``; each
    measurement is pickled as it is produced and unpickled again before being
    handed to ``Compare``. The comparison is printed twice: once as-is, and
    once after ``trim_significant_figures()`` and ``colorize()``.
    """
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]
    environments = [("master", None), ("my_branch", 1), ("severe_regression", 5)]
    sizes = [1, 10, 100, 1000, 10000, 50000]
    thread_counts = [1, 4]
    repeats = 2

    timers = []
    for branch, overhead_ns in environments:
        for label, sub_label, stmt in tasks:
            for size in sizes:
                for num_threads in thread_counts:
                    # "master" uses the real torch module; every other
                    # environment gets a fresh FauxTorch wrapper (defined
                    # elsewhere in this file -- presumably it injects
                    # `overhead_ns` of artificial overhead; confirm there).
                    torch_impl = (
                        torch if branch == "master"
                        else FauxTorch(torch, overhead_ns)
                    )
                    timer_globals = {
                        "torch": torch_impl,
                        "x": torch.ones((size, 4)),
                        "y": torch.ones((1, 4)),
                        "zero": torch.zeros(()),
                    }
                    timers.append(
                        benchmark_utils.Timer(
                            stmt=stmt,
                            globals=timer_globals,
                            label=label,
                            sub_label=sub_label,
                            description=f"size: {size}",
                            env=branch,
                            num_threads=num_threads,
                        )
                    )

    total_runs = len(timers) * repeats
    serialized_results = []
    for completed, timer in enumerate(timers * repeats, start=1):
        measurement = timer.blocked_autorange(min_run_time=0.05)
        serialized_results.append(pickle.dumps(measurement))
        # "\r" rewrites the same console line to act as a progress counter.
        print(f"\r{completed} / {total_runs}", end="")
        sys.stdout.flush()
    print()

    # The pickled bytes were produced by this very function a moment ago, so
    # unpickling them here does not involve untrusted input.
    comparison = benchmark_utils.Compare(
        [pickle.loads(blob) for blob in serialized_results])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()
def run_bench(model_names, bench_args):
    """Time each named model with and without RecordFunction and print a comparison.

    Args:
        model_names: Iterable of names; each is looked up in the module-level
            ``MODELS`` to obtain a creator that is called with ``bench_args``
            and returns ``(inputs, model)``.
        bench_args: Object providing ``warmup`` (number of warmup calls) and
            ``timer_min_run_time`` (forwarded to ``blocked_autorange``).

    For every model, the function runs the warmup calls and then, for each
    entry of the module-level ``NUM_THREADS`` and for RecordFunction enabled
    and disabled, collects one measurement of ``model(*inputs)``. All
    measurements are finally rendered through ``Compare``.
    """
    measurements = []

    for name in model_names:
        inputs, model = MODELS[name](bench_args)

        print("Benchmarking RecordFunction overhead for", name)
        print("Running warmup...", end=" ")
        sys.stdout.flush()
        for _ in range(bench_args.warmup):
            model(*inputs)
        print("finished")

        for thread_count in NUM_THREADS:
            for record_fn_on in (True, False):
                # Toggle RecordFunction and drop any previously registered
                # callbacks so each configuration starts from the same state.
                torch.autograd._enable_record_function(record_fn_on)
                torch.autograd._clear_callbacks()
                if record_fn_on:
                    # NOTE(review): arguments are passed through unchanged;
                    # their meaning is defined by torch.autograd -- confirm.
                    torch.autograd._set_empty_test_observer(True, 0.0001)

                mode = "with" if record_fn_on else "without"
                print(
                    f"Running {mode} RecordFunction, num threads {thread_count} ...",
                    end=" ")
                sys.stdout.flush()

                measurement = benchmark_utils.Timer(
                    stmt="model(*inputs)",
                    globals={"model": model, "inputs": inputs},
                    description=name,
                    label="Record function overhead",
                    sub_label=f"{mode}_rec_fn, num_threads {thread_count}",
                    num_threads=thread_count,
                ).blocked_autorange(min_run_time=bench_args.timer_min_run_time)

                print("finished")
                print(measurement)
                sys.stdout.flush()
                measurements.append(measurement)

    comparison = benchmark_utils.Compare(measurements)
    comparison.trim_significant_figures()
    comparison.highlight_warnings()
    comparison.print()
def test_compare(self):
    """Check the text rendered by ``Compare`` for mock-timed measurements.

    Three mock timer classes are built from a simple cost model
    (``base + per_element * i * j``), measurements are collected for every
    entry of ``sizes``, and ``str(compare)`` is compared against an expected
    table three times: as constructed, after ``trim_significant_figures()``,
    and after ``colorize()``.
    """
    # Simulate several approaches.
    # Each entry is (base cost, cost per element); the cost used for an
    # (i, j) input is `entry[0] + entry[1] * i * j` (see the classes below).
    costs = (
        # overhead_optimized_fn()
        (1e-6, 1e-9),

        # compute_optimized_fn()
        (3e-6, 5e-10),

        # special_case_fn() [square inputs only]
        (1e-6, 4e-10),
    )

    # (i, j) input shapes; three of them are square (i == j).
    sizes = (
        (16, 16),
        (16, 128),
        (128, 128),
        (4096, 1024),
        (2048, 2048),
    )

    # overhead_optimized_fn()
    # `_function_costs` pairs each statement string "fn(i, j)" with its
    # modelled cost. NOTE(review): presumably `self._MockTimer` (defined
    # outside this block) looks statements up in this table -- confirm.
    class _MockTimer_0(self._MockTimer):
        _function_costs = tuple(
            (f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j)
            for i, j in sizes)

    # NOTE(review): overriding `_timer_cls` presumably makes this Timer use
    # the mock timer above instead of a real one -- confirm against Timer.
    class MockTimer_0(benchmark_utils.Timer):
        _timer_cls = _MockTimer_0

    # compute_optimized_fn()
    class _MockTimer_1(self._MockTimer):
        _function_costs = tuple(
            (f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j)
            for i, j in sizes)

    class MockTimer_1(benchmark_utils.Timer):
        _timer_cls = _MockTimer_1

    # special_case_fn()
    # Only square shapes get an entry here, matching the `if i == j` guard
    # in the collection loop below.
    class _MockTimer_2(self._MockTimer):
        _function_costs = tuple(
            (f"fn({i}, {j})", costs[2][0] + costs[2][1] * i * j)
            for i, j in sizes if i == j)

    class MockTimer_2(benchmark_utils.Timer):
        _timer_cls = _MockTimer_2

    # Collect one measurement per (approach, shape). All share label "fn";
    # `sub_label` names the approach and `description` names the shape.
    results = []
    for i, j in sizes:
        results.append(
            MockTimer_0(
                f"fn({i}, {j})",
                label="fn",
                description=f"({i}, {j})",
                sub_label="overhead_optimized",
            ).blocked_autorange(min_run_time=10))
        results.append(
            MockTimer_1(
                f"fn({i}, {j})",
                label="fn",
                description=f"({i}, {j})",
                sub_label="compute_optimized",
            ).blocked_autorange(min_run_time=10))
        # The special-case approach is only measured on square inputs.
        if i == j:
            results.append(
                MockTimer_2(
                    f"fn({i}, {j})",
                    label="fn",
                    description=f"({i}, {j})",
                    sub_label="special_case (square)",
                ).blocked_autorange(min_run_time=10))

    # Compare rendered output with an expected literal: `output` is stripped
    # and each of its lines right-stripped; `expected` is dedented and
    # stripped. No other whitespace normalization is applied.
    def check_output(output: str, expected: str):
        # VSCode will strip trailing newlines from `expected`, so we have to match
        # this behavior when comparing output.
        output_str = "\n".join(
            i.rstrip() for i in output.strip().splitlines(keepends=False))
        self.assertEqual(output_str, textwrap.dedent(expected).strip())

    compare = benchmark_utils.Compare(results)

    # NOTE(review): the expected tables below are reproduced exactly as they
    # appear in this view, where line breaks and column alignment inside the
    # literals are not visible. Because check_output only strips/dedents,
    # the literals must match the rendered layout line for line -- confirm
    # them against the file / the actual output of str(compare).

    # 1) Rendering as constructed.
    check_output(
        str(compare),
        """ [------------------------------------------------- fn ------------------------------------------------] | (16, 16) | (16, 128) | (128, 128) | (4096, 1024) | (2048, 2048) 1 threads: -------------------------------------------------------------------------------------------- overhead_optimized | 1.3 | 3.0 | 17.4 | 4174.4 | 4174.4 compute_optimized | 3.1 | 4.0 | 11.2 | 2099.3 | 2099.3 special_case (square) | 1.1 | | 7.5 | | 1674.7 Times are in microseconds (us).""")

    # 2) Rendering after trimming significant figures.
    compare.trim_significant_figures()
    check_output(
        str(compare),
        """ [------------------------------------------------- fn ------------------------------------------------] | (16, 16) | (16, 128) | (128, 128) | (4096, 1024) | (2048, 2048) 1 threads: -------------------------------------------------------------------------------------------- overhead_optimized | 1 | 3.0 | 17 | 4200 | 4200 compute_optimized | 3 | 4.0 | 11 | 2100 | 2100 special_case (square) | 1 | | 8 | | 1700 Times are in microseconds (us).""")

    # 3) Rendering after colorize(): some cells are expected to be wrapped in
    # ANSI escape sequences (the `\x1b[...m` codes in the literal).
    compare.colorize()
    check_output(str(compare), """ [------------------------------------------------- fn ------------------------------------------------] | (16, 16) | (16, 128) | (128, 128) | (4096, 1024) | (2048, 2048) 1 threads: -------------------------------------------------------------------------------------------- overhead_optimized | 1 | \x1b[92m\x1b[1m 3.0 \x1b[0m\x1b[0m | \x1b[2m\x1b[91m 17 \x1b[0m\x1b[0m | 4200 | \x1b[2m\x1b[91m 4200 \x1b[0m\x1b[0m compute_optimized | \x1b[2m\x1b[91m 3 \x1b[0m\x1b[0m | 4.0 | 11 | \x1b[92m\x1b[1m 2100 \x1b[0m\x1b[0m | 2100 special_case (square) | \x1b[92m\x1b[1m 1 \x1b[0m\x1b[0m | | \x1b[92m\x1b[1m 8 \x1b[0m\x1b[0m | | \x1b[92m\x1b[1m 1700 \x1b[0m\x1b[0m Times are in microseconds (us)."""  # noqa
    )