コード例 #1
0
ファイル: test_utils.py プロジェクト: waoup/pytorch
 def test_compare(self):
     compare = benchmark_utils.Compare([
         benchmark_utils.Timer(
             "torch.ones((n,))", globals={"n": n},
             description="ones", label=str(n)).timeit(3)
         for n in range(3)
     ])
     compare.print()
コード例 #2
0
def main():
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]

    serialized_results = []
    repeats = 2
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "torch":
                torch if branch == "master" else FauxTorch(torch, overhead_ns),
                "x":
                torch.ones((size, 4)),
                "y":
                torch.ones((1, 4)),
                "zero":
                torch.zeros(()),
            },
            label=label,
            sub_label=sub_label,
            description=f"size: {size}",
            env=branch,
            num_threads=num_threads,
        ) for branch, overhead_ns in [("master",
                                       None), ("my_branch",
                                               1), ("severe_regression", 5)]
        for label, sub_label, stmt in tasks
        for size in [1, 10, 100, 1000, 10000, 50000] for num_threads in [1, 4]
    ]

    for i, timer in enumerate(timers * repeats):
        serialized_results.append(
            pickle.dumps(timer.blocked_autorange(min_run_time=0.05)))
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

    comparison = benchmark_utils.Compare(
        [pickle.loads(i) for i in serialized_results])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()
コード例 #3
0
def run_bench(model_names, bench_args):
    results = []
    for model_name in model_names:
        model_creator = MODELS[model_name]
        inputs, model = model_creator(bench_args)

        print("Benchmarking RecordFunction overhead for", model_name)
        print("Running warmup...", end=" ")
        sys.stdout.flush()
        for _ in range(bench_args.warmup):
            model(*inputs)
        print("finished")

        for num_threads in NUM_THREADS:
            for with_rec_fn in [True, False]:
                torch.autograd._enable_record_function(with_rec_fn)
                torch.autograd._clear_callbacks()
                if with_rec_fn:
                    torch.autograd._set_empty_test_observer(True, 0.0001)

                print("Running {} RecordFunction, num threads {} ...".format(
                    "with" if with_rec_fn else "without", num_threads),
                      end=" ")
                sys.stdout.flush()
                timer = benchmark_utils.Timer(
                    stmt="model(*inputs)",
                    globals={
                        "model": model,
                        "inputs": inputs
                    },
                    description=model_name,
                    label="Record function overhead",
                    sub_label=
                    f"with{'' if with_rec_fn else 'out'}_rec_fn, num_threads {num_threads}",
                    num_threads=num_threads)
                result = timer.blocked_autorange(
                    min_run_time=bench_args.timer_min_run_time)
                print("finished")
                print(result)
                sys.stdout.flush()
                results.append(result)

    comparison = benchmark_utils.Compare(results)
    comparison.trim_significant_figures()
    comparison.highlight_warnings()
    comparison.print()
コード例 #4
0
    def test_compare(self):
        # Simulate several approaches.
        costs = (
            # overhead_optimized_fn()
            (1e-6, 1e-9),

            # compute_optimized_fn()
            (3e-6, 5e-10),

            # special_case_fn()  [square inputs only]
            (1e-6, 4e-10),
        )

        sizes = (
            (16, 16),
            (16, 128),
            (128, 128),
            (4096, 1024),
            (2048, 2048),
        )

        # overhead_optimized_fn()
        class _MockTimer_0(self._MockTimer):
            _function_costs = tuple(
                (f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j)
                for i, j in sizes)

        class MockTimer_0(benchmark_utils.Timer):
            _timer_cls = _MockTimer_0

        # compute_optimized_fn()
        class _MockTimer_1(self._MockTimer):
            _function_costs = tuple(
                (f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j)
                for i, j in sizes)

        class MockTimer_1(benchmark_utils.Timer):
            _timer_cls = _MockTimer_1

        # special_case_fn()
        class _MockTimer_2(self._MockTimer):
            _function_costs = tuple(
                (f"fn({i}, {j})", costs[2][0] + costs[2][1] * i * j)
                for i, j in sizes if i == j)

        class MockTimer_2(benchmark_utils.Timer):
            _timer_cls = _MockTimer_2

        results = []
        for i, j in sizes:
            results.append(
                MockTimer_0(
                    f"fn({i}, {j})",
                    label="fn",
                    description=f"({i}, {j})",
                    sub_label="overhead_optimized",
                ).blocked_autorange(min_run_time=10))

            results.append(
                MockTimer_1(
                    f"fn({i}, {j})",
                    label="fn",
                    description=f"({i}, {j})",
                    sub_label="compute_optimized",
                ).blocked_autorange(min_run_time=10))

            if i == j:
                results.append(
                    MockTimer_2(
                        f"fn({i}, {j})",
                        label="fn",
                        description=f"({i}, {j})",
                        sub_label="special_case (square)",
                    ).blocked_autorange(min_run_time=10))

        def check_output(output: str, expected: str):
            # VSCode will strip trailing newlines from `expected`, so we have to match
            # this behavior when comparing output.
            output_str = "\n".join(
                i.rstrip() for i in output.strip().splitlines(keepends=False))

            self.assertEqual(output_str, textwrap.dedent(expected).strip())

        compare = benchmark_utils.Compare(results)

        check_output(
            str(compare), """
            [------------------------------------------------- fn ------------------------------------------------]
                                         |  (16, 16)  |  (16, 128)  |  (128, 128)  |  (4096, 1024)  |  (2048, 2048)
            1 threads: --------------------------------------------------------------------------------------------
                  overhead_optimized     |    1.3     |     3.0     |     17.4     |     4174.4     |     4174.4
                  compute_optimized      |    3.1     |     4.0     |     11.2     |     2099.3     |     2099.3
                  special_case (square)  |    1.1     |             |      7.5     |                |     1674.7

            Times are in microseconds (us).""")

        compare.trim_significant_figures()
        check_output(
            str(compare), """
            [------------------------------------------------- fn ------------------------------------------------]
                                         |  (16, 16)  |  (16, 128)  |  (128, 128)  |  (4096, 1024)  |  (2048, 2048)
            1 threads: --------------------------------------------------------------------------------------------
                  overhead_optimized     |     1      |     3.0     |      17      |      4200      |      4200
                  compute_optimized      |     3      |     4.0     |      11      |      2100      |      2100
                  special_case (square)  |     1      |             |       8      |                |      1700

            Times are in microseconds (us).""")

        compare.colorize()
        check_output(str(compare), """
            [------------------------------------------------- fn ------------------------------------------------]
                                         |  (16, 16)  |  (16, 128)  |  (128, 128)  |  (4096, 1024)  |  (2048, 2048)
            1 threads: --------------------------------------------------------------------------------------------
                  overhead_optimized     |     1      |  \x1b[92m\x1b[1m   3.0   \x1b[0m\x1b[0m  |  \x1b[2m\x1b[91m    17    \x1b[0m\x1b[0m  |      4200      |  \x1b[2m\x1b[91m    4200    \x1b[0m\x1b[0m
                  compute_optimized      |  \x1b[2m\x1b[91m   3    \x1b[0m\x1b[0m  |     4.0     |      11      |  \x1b[92m\x1b[1m    2100    \x1b[0m\x1b[0m  |      2100
                  special_case (square)  |  \x1b[92m\x1b[1m   1    \x1b[0m\x1b[0m  |             |  \x1b[92m\x1b[1m     8    \x1b[0m\x1b[0m  |                |  \x1b[92m\x1b[1m    1700    \x1b[0m\x1b[0m

            Times are in microseconds (us)."""

                     # noqa
                     )