All snippets below appear to assume that torch.utils.benchmark has been imported as benchmark_utils (e.g. from torch.utils import benchmark as benchmark_utils), together with torch and, where used, sys and pickle.

Example #1
def test_adaptive_timer(self):
    # Validate on two different sizes, and validate against blocked_autorange.
    # This looks for relative differences between orders of magnitude to
    # provide a stable/portable test which is somewhat informative.
    timer = benchmark_utils.Timer(stmt="torch.sum(torch.ones((10, 10)))")
    small = timer.adaptive_autorange(min_run_time=0.1, max_run_time=1.0)

    timer = benchmark_utils.Timer(stmt="torch.sum(torch.ones((500, 500)))")
    medium = timer.adaptive_autorange(min_run_time=0.1, max_run_time=1.0)
    blocked_medium = timer.blocked_autorange(min_run_time=0.1)

    self.assertLess(small.median, medium.median)
    # This acts as a control: a different way of measuring the same value.
    self.assertLess(small.median, blocked_medium.median)
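Both autorange calls above return a Measurement object; the test only compares medians, but the raw per-run numbers are available too. A minimal standalone sketch of inspecting one (the statement and time limits here are arbitrary choices):

import torch
import torch.utils.benchmark as benchmark_utils

timer = benchmark_utils.Timer(stmt="torch.sum(torch.ones((10, 10)))")
m = timer.adaptive_autorange(min_run_time=0.1, max_run_time=1.0)

print(m.median)  # median seconds per run
print(m.mean)    # mean seconds per run
print(m.times)   # per-run times that the statistics are computed from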
Example #2
def test_compare(self):
    compare = benchmark_utils.Compare([
        benchmark_utils.Timer(
            "torch.ones((n,))",
            globals={"n": n},
            description="ones",
            label=str(n),
        ).timeit(3)
        for n in range(3)
    ])
    compare.print()
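For context on how Compare lays out its table: label picks the table, sub_label the row, and description the column (this grouping is my reading of the API, not something stated in the snippet). A small sketch with two rows and one column:

import torch
import torch.utils.benchmark as benchmark_utils

measurements = [
    benchmark_utils.Timer(
        stmt=stmt,
        globals={"x": torch.ones((1024,))},
        label="pointwise",        # table title
        sub_label=sub_label,      # row
        description="float32",    # column
    ).blocked_autorange(min_run_time=0.05)
    for sub_label, stmt in [("add", "x + x"), ("mul", "x * x")]
]
benchmark_utils.Compare(measurements).print()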
Example #3
def test_timer(self):
    timer = benchmark_utils.Timer(stmt="torch.ones(())")
    median = timer.blocked_autorange(min_run_time=0.01).median
    self.assertIsInstance(median, float)

    # We set a very high threshold to avoid flakiness in CI.
    # The internal algorithm is tested in `test_adaptive_timer`.
    median = timer.adaptive_autorange(threshold=0.5).median
    self.assertIsInstance(median, float)
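The threshold argument bounds the accepted measurement noise: adaptive_autorange keeps collecting runs until the relative spread (roughly interquartile range over median) falls below it, or until max_run_time is exhausted. That spread can be checked directly, as a sketch:

import torch
import torch.utils.benchmark as benchmark_utils

timer = benchmark_utils.Timer(stmt="torch.ones(())")
m = timer.adaptive_autorange(threshold=0.5)
print(m.iqr / m.median)  # relative spread the threshold is judged against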
Example #4
def main():
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]

    serialized_results = []
    repeats = 2
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                # FauxTorch (defined elsewhere in the original script) wraps
                # torch and adds simulated per-call overhead in nanoseconds.
                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
                "x": torch.ones((size, 4)),
                "y": torch.ones((1, 4)),
                "zero": torch.zeros(()),
            },
            label=label,
            sub_label=sub_label,
            description=f"size: {size}",
            env=branch,
            num_threads=num_threads,
        )
        for branch, overhead_ns in [
            ("master", None),
            ("my_branch", 1),
            ("severe_regression", 5),
        ]
        for label, sub_label, stmt in tasks
        for size in [1, 10, 100, 1000, 10000, 50000]
        for num_threads in [1, 4]
    ]

    for i, timer in enumerate(timers * repeats):
        serialized_results.append(
            pickle.dumps(timer.blocked_autorange(min_run_time=0.05)))
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

    comparison = benchmark_utils.Compare(
        [pickle.loads(i) for i in serialized_results])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()
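The pickling in this example is the point: Measurement objects survive a round trip, so each branch or environment can be benchmarked in its own process and the results merged later. A sketch of splitting that across runs (measurements.pkl is a hypothetical file name):

import pickle

import torch
import torch.utils.benchmark as benchmark_utils

# Process 1..N: benchmark one environment and append the result.
m = benchmark_utils.Timer(
    stmt="torch.add(x, y)",
    globals={"x": torch.ones((100, 4)), "y": torch.ones((1, 4))},
    label="add",
    env="master",  # set per branch/environment
).blocked_autorange(min_run_time=0.05)
with open("measurements.pkl", "ab") as f:
    pickle.dump(m, f)

# Final process: load everything and compare.
results = []
with open("measurements.pkl", "rb") as f:
    while True:
        try:
            results.append(pickle.load(f))
        except EOFError:
            break
benchmark_utils.Compare(results).print()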
Example #5
def main():
    timer = benchmark_utils.Timer(
        stmt="x + y",
        globals={
            "x": torch.ones((4, 8)),
            "y": torch.ones((1, 8))
        },
        label="Broadcasting add (4x8)",
    )

    for i in range(3):
        print(f"Run: {i}\n{'-' * 40}")
        print(f"timeit:\n{timer.timeit(10000)}\n")
        print(f"autorange:\n{timer.blocked_autorange()}\n\n")
Example #6
def run_bench(model_names, bench_args):
    results = []
    for model_name in model_names:
        model_creator = MODELS[model_name]
        inputs, model = model_creator(bench_args)

        print("Benchmarking RecordFunction overhead for", model_name)
        print("Running warmup...", end=" ")
        sys.stdout.flush()
        for _ in range(bench_args.warmup):
            model(*inputs)
        print("finished")

        for num_threads in NUM_THREADS:
            for with_rec_fn in [True, False]:
                torch.autograd._enable_record_function(with_rec_fn)
                torch.autograd._clear_callbacks()
                if with_rec_fn:
                    torch.autograd._set_empty_test_observer(True, 0.0001)

                print("Running {} RecordFunction, num threads {} ...".format(
                    "with" if with_rec_fn else "without", num_threads),
                      end=" ")
                sys.stdout.flush()
                timer = benchmark_utils.Timer(
                    stmt="model(*inputs)",
                    globals={"model": model, "inputs": inputs},
                    description=model_name,
                    label="Record function overhead",
                    sub_label=f"with{'' if with_rec_fn else 'out'}_rec_fn, "
                              f"num_threads {num_threads}",
                    num_threads=num_threads,
                )
                result = timer.blocked_autorange(
                    min_run_time=bench_args.timer_min_run_time)
                print("finished")
                print(result)
                sys.stdout.flush()
                results.append(result)

    comparison = benchmark_utils.Compare(results)
    comparison.trim_significant_figures()
    comparison.highlight_warnings()
    comparison.print()
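highlight_warnings() marks table entries whose underlying measurement was noisy; the same flag is exposed on each result, so unreliable runs can also be filtered before building the table. A sketch (the statement is arbitrary):

import torch
import torch.utils.benchmark as benchmark_utils

m = benchmark_utils.Timer(
    stmt="torch.ones(())",
).blocked_autorange(min_run_time=0.05)

if m.has_warnings:
    print("noisy measurement, consider a longer min_run_time:", m.median)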
Example #7
def test_timer(self):
    timer = benchmark_utils.Timer(stmt="torch.ones(())")
    median = timer.blocked_autorange(min_run_time=0.1).median
    self.assertIsInstance(median, float)
Example #8
def main():
    add_fuzzer = benchmark_utils.Fuzzer(
        parameters=[
            [
                benchmark_utils.FuzzedParameter(
                    name=f"k{i}",
                    minval=16,
                    maxval=16 * 1024,
                    distribution="loguniform",
                ) for i in range(3)
            ],
            benchmark_utils.FuzzedParameter(
                name="d",
                distribution={
                    2: 0.6,
                    3: 0.4
                },
            ),
        ],
        tensors=[
            [
                benchmark_utils.FuzzedTensor(
                    name=name,
                    size=("k0", "k1", "k2"),
                    dim_parameter="d",
                    probability_contiguous=0.75,
                    min_elements=64 * 1024,
                    max_elements=128 * 1024,
                ) for name in ("x", "y")
            ],
        ],
        seed=0,
    )

    n = 250
    measurements = []
    for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)):
        x, x_order = tensors["x"], str(tensor_properties["x"]["order"])
        y, y_order = tensors["y"], str(tensor_properties["y"]["order"])
        shape = ", ".join(tuple(f'{i:>4}' for i in x.shape))

        description = "".join([
            f"{x.numel():>7} | {shape:<16} | ",
            f"{'contiguous' if x.is_contiguous() else x_order:<12} | ",
            f"{'contiguous' if y.is_contiguous() else y_order:<12} | ",
        ])

        timer = benchmark_utils.Timer(
            stmt="x + y",
            globals=tensors,
            description=description,
        )

        measurements.append(timer.blocked_autorange(min_run_time=0.1))
        measurements[-1].metadata = {"numel": x.numel()}
        print(f"\r{i + 1} / {n}", end="")
        sys.stdout.flush()
    print()

    # More string munging to make pretty output.
    print(f"Average attempts per valid config: "
          f"{1. / (1. - add_fuzzer.rejection_rate):.1f}")

    def time_fn(m):
        return m.median / m.metadata["numel"]

    measurements.sort(key=time_fn)

    template = f"{{:>6}}{' ' * 19}Size    Shape{' ' * 13}X order        Y order\n{'-' * 80}"
    print(template.format("Best:"))
    for m in measurements[:15]:
        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")

    print("\n" + template.format("Worst:"))
    for m in measurements[-15:]:
        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")