Example #1
0
    def test_fuse_linear_pattern_match(self, shape, out_features):
        """Check that TVM graph rewriting fuses the three linear-layer
        patterns (addmm, matmul+add, bare matmul) into a single
        ``aten::linear`` node and removes the weight transpose."""
        input = torch.rand(shape)
        weight = torch.rand(out_features, shape[1])
        bias = torch.rand(out_features)

        def linear_addmm(input, weight, bias):
            return torch.addmm(bias, input, weight.t())

        def linear_matmul_add(input, weight, bias):
            output = input.matmul(weight.t())
            output += bias
            return output

        def linear_matmul(input, weight):
            return input.matmul(weight.t())

        import torch_tvm
        torch_tvm.enable()
        # test addmm
        # BUG FIX: previously scripted linear_matmul_add here, so the
        # addmm pattern was never actually exercised by this test.
        scripted_addmm = torch.jit.script(linear_addmm)
        addmm_graph = scripted_addmm.graph_for(input, weight, bias)
        FileCheck().check("aten::linear").check_not("addmm").check_not("aten::t").run(str(addmm_graph))

        # test matmul + add
        scripted_matmul_add = torch.jit.script(linear_matmul_add)
        matmul_add_graph = scripted_matmul_add.graph_for(input, weight, bias)
        FileCheck().check("aten::linear").check_not("matmul").check_not("aten::t").run(str(matmul_add_graph))

        # test matmul
        scripted_matmul = torch.jit.script(linear_matmul)
        matmul_graph = scripted_matmul.graph_for(input, weight)
        FileCheck().check("aten::linear").check_not("matmul").check_not("aten::t").run(str(matmul_graph))
        torch_tvm.disable()
Example #2
0
    def test_dropout_removal(self, shape):
        """Dropout must survive tracing in training mode but be removed
        from the optimized graph in inference mode."""
        a = torch.rand(shape)
        b = torch.rand(shape)
        c = torch.rand(shape)

        def dropout_training(x, y, z):
            total = x + y
            dropped = torch.dropout(total, 0.1, True)
            return dropped + z

        def dropout_inference(x, y, z):
            total = x + y
            dropped = torch.dropout(total, 0.1, False)
            return dropped + z

        torch_tvm.enable()
        traced_training = torch.jit.trace(dropout_training, (a, b, c))
        traced_inference = torch.jit.trace(dropout_inference, (a, b, c))
        torch_tvm.disable()

        training_graph = str(traced_training.graph_for(a, b, c))
        inference_graph = str(traced_inference.graph_for(a, b, c))
        assert "aten::dropout" in training_graph, \
            "dropout must not be removed during training."
        assert "aten::dropout" not in inference_graph, \
            "dropout must be removed during inference."
Example #3
0
    def runBoth(self, func, *inputs, check_tvm=True):
        """Trace ``func`` twice — plain JIT and TVM-lowered — and return
        ``(jit_output, tvm_output)`` for the caller to compare.

        Args:
            func: callable to trace.
            *inputs: example tensors used for tracing and execution.
            check_tvm: when truthy, assert the optimized graph contains a
                ``tvm::CompilationGroup`` and that a "TVM" event shows up
                in the profiler during execution.
        """
        with torch.no_grad():
            # jit the function
            trace_jit = torch.jit.trace(func, inputs)
            ref_out = trace_jit(*inputs)

            # jit the function and lower to TVM, applying any tuned
            # schedules recorded in autotvm_tuning.log next to this file
            torch_tvm.enable()
            d = os.path.dirname(os.path.abspath(__file__))
            fn = os.path.join(d, "autotvm_tuning.log")

            with autotvm.apply_history_best(fn):
                trace_tvm = torch.jit.trace(func, inputs)
                try:
                    tvm_out = trace_tvm(*inputs)
                except Exception:
                    # Dump the offending graph for debugging, then re-raise.
                    # BUG FIX: bare `raise` preserves the original traceback,
                    # unlike the previous `raise e`.
                    print("Error with graph\n{}".format(trace_tvm.graph))
                    raise

            if check_tvm:  # idiomatic truth test instead of `== True`
                tvm_unused = "TVM was not able to optimize this trace."
                assert "tvm::CompilationGroup" in str(
                    trace_tvm.graph_for(*inputs)
                ), tvm_unused + " Graph:\n" + str(trace_tvm.graph_for(*inputs))
                # tvm compile the graph and ensure TVM is used
                with profile() as p:
                    _ = trace_tvm(*inputs)
                assert "TVM" in [evt.name for evt in p.function_events], tvm_unused

            torch_tvm.disable()

            return ref_out, tvm_out
Example #4
0
def benchmark(model, csv_file, input_fn=genImage, iters=100, warmup=10):
    """Benchmark ``model`` under plain JIT and under TVM lowering.

    Traces the model twice (JIT, then TVM with autotvm-tuned schedules),
    times ``iters`` runs of each after ``warmup`` runs, prints iter/s for
    both, and optionally appends the TVM throughput to ``csv_file``.

    Args:
        model: the torch module/callable to benchmark.
        csv_file: path to append "timestamp,iter_per_sec" rows to, or a
            falsy value to skip CSV output.
        input_fn: zero-arg callable producing the example input tuple.
        iters: number of timed iterations per backend.
        warmup: number of untimed warm-up iterations per backend.
    """
    with torch.no_grad():
        torch_tvm.disable()
        inputs = input_fn()
        print("Tracing model with JIT")
        trace_jit = torch.jit.trace(model, inputs)
        print("Warming JIT up with {} runs".format(warmup))
        for _ in range(warmup):
            _ = trace_jit(*inputs)

        print("Running JIT {} times".format(iters))
        start = time.time()
        for _ in range(iters):
            _ = trace_jit(*inputs)
        jit_time = time.time() - start
        print("Done benchmarking JIT")

        # Apply any autotvm-tuned schedules stored next to this script.
        d = os.path.dirname(os.path.abspath(__file__))
        fn = os.path.join(d, "autotvm_tuning.log")
        with autotvm.apply_history_best(fn):
            torch_tvm.enable(opt_level=3,
                             device_type="cpu",
                             device="llvm -mcpu=core-avx2",
                             host="llvm -mcpu=core-avx2")
            print("Tracing model with TVM")
            trace_tvm = torch.jit.trace(model, inputs)
            print("Warming TVM up with {} iters".format(warmup))
            for _ in range(warmup):
                _ = trace_tvm(*inputs)

            print("Running TVM {} times".format(iters))
            start = time.time()
            for _ in range(iters):
                _ = trace_tvm(*inputs)
            tvm_time = time.time() - start

            # Profile one run to estimate the fraction of compute TVM handled.
            with torch.autograd.profiler.profile() as prof:
                _ = trace_tvm(*inputs)
            tvm_profiled_time = 0
            total_profiled_time = 0
            for p in prof.key_averages():
                total_profiled_time += int(p.cpu_time)
                if p.key == "TVM":
                    tvm_profiled_time += int(p.cpu_time)

            # BUG FIX: guard against ZeroDivisionError when the profiler
            # records no cpu time at all.
            tvm_pct = (100 * tvm_profiled_time / total_profiled_time
                       if total_profiled_time else 0.0)
            print("Done benchmarking TVM, which compiled {:.2f}% of compute".
                  format(tvm_pct))
            if csv_file:
                exists = os.path.isfile(csv_file)
                with open(csv_file, 'a' if exists else 'w') as f:
                    if not exists:
                        f.write("timestamp,iter_per_sec\n")
                    f.write("{},{}\n".format(int(time.time()),
                                             iters / tvm_time))
        print("JIT: {} iter/s\nTVM: {} iter/s".format(iters / jit_time,
                                                      iters / tvm_time))
Example #5
0
    def checkTraceTVM(self,
                      func,
                      input_tensors=None,
                      input_shapes=None,
                      size=100000,
                      runs=100,
                      verbose=False):
        """Trace ``func`` with and without TVM, verify TVM was actually
        used, time ``runs`` executions of each trace, and assert both
        traces produce equal outputs.

        Args:
            func: callable to trace.
            input_tensors: explicit example inputs; generated when None.
            input_shapes: iterable of shapes used to generate inputs when
                ``input_tensors`` is None; when both are None, three
                identical 1-D tensors of length ``size`` are used.
            size: length of generated 1-D inputs (default path only).
            runs: number of timed iterations per backend.
            verbose: print a one-line timing summary when True.
        """
        # prepare inputs
        if input_tensors is None:
            if input_shapes is None:
                # values are scaled down so repeated additions stay small
                seed = torch.rand(size) / runs / 2
                input_tensors = (seed, seed, seed)
            else:
                input_tensors = []
                for shape in input_shapes:
                    seed = torch.rand(*shape) / runs / 2
                    input_tensors.append(seed)

        # jit the function
        trace_jit = torch.jit.trace(func, input_tensors)
        # specialize the graph with the inputs
        _ = trace_jit(*input_tensors)
        # timeit the perf
        jit_start = time.time()
        for _ in range(runs):
            outputs_jit = trace_jit(*input_tensors)
        jit_time = time.time() - jit_start

        # jit the function and lower to TVM
        torch_tvm.enable()
        trace_tvm = torch.jit.trace(func, input_tensors)
        tvm_unused = "TVM was not able to optimize this trace."
        assert "tvm::CompilationGroup" in str(
            trace_tvm.graph_for(*input_tensors)), tvm_unused
        # tvm compile the graph and ensure TVM is used
        with profile() as p:
            _ = trace_tvm(*input_tensors)
        assert "TVM" in [_.name for _ in p.function_events], tvm_unused
        torch_tvm.disable()
        # timeit the perf
        # NOTE(review): timing happens after torch_tvm.disable();
        # presumably the already-compiled trace still runs its TVM
        # kernels and disable() only stops new compilations — confirm.
        tvm_start = time.time()
        for _ in range(runs):
            outputs_tvm = trace_tvm(*input_tensors)
        tvm_time = time.time() - tvm_start

        if verbose:
            print("\noperator " + func.__name__ +
                  ":\t{} runs of size {}".format(runs, size) +
                  " \tjit time:{:.4f}s".format(jit_time) +
                  "\ttvm time:{:.4f}s".format(tvm_time))
        self.assertEqual(outputs_jit, outputs_tvm)
Example #6
0
def benmarch_pytorch_lstm(is_tvm=False, opt_level=0):
    """Benchmark the PyTorch LSTM model, optionally lowered through TVM.

    Traces the model, warms it up, times 100 runs, and — when ``is_tvm``
    is True — reports what fraction of profiled compute ran inside TVM.
    (Function name typo is kept for backward compatibility with callers.)

    Args:
        is_tvm: enable torch_tvm lowering before tracing when True.
        opt_level: TVM optimization level forwarded to torch_tvm.enable().
    """
    torch_tvm.disable()
    # Ensure your model is in eval mode and also turn off gradients.
    with torch.no_grad():
        inputs = model_inputs()
        model = get_pytorch_model()
        if is_tvm:
            torch_tvm.enable(opt_level=opt_level,
                             device_type="cpu",
                             device="llvm -mcpu=core-avx2",
                             host="llvm -mcpu=core-avx2")

        # This is where all the compilation happens.
        mod = torch.jit.trace(model, inputs)

        dry_run = 10  # use 10 iterations to warm up
        run = 100
        for i in range(dry_run + run):
            if i == dry_run:
                tic = time.time()
            # NOTE(review): `inputs` is passed as a single argument here
            # (not unpacked) — presumably model_inputs() returns a single
            # tensor/object; confirm against model_inputs before changing.
            _ = mod(inputs)
        time_iter = (time.time() - tic) * 1000 / run
        print(
            f"{get_benchmark_name(is_tvm, opt_level)}, timing: {time_iter} ms")

        if is_tvm:
            if "tvm::CompilationGroup" not in str(mod.graph_for(inputs)):
                print("TVM was not able to optimize this trace.")
            else:
                with torch.autograd.profiler.profile() as prof:
                    _ = mod(inputs)
                tvm_profiled_time = 0
                total_profiled_time = 0
                for p in prof.key_averages():
                    total_profiled_time += int(p.cpu_time)
                    if p.key == "TVM":
                        tvm_profiled_time += int(p.cpu_time)
                # BUG FIX: guard against ZeroDivisionError when the
                # profiler records no cpu time at all.
                tvm_pct = (100 * tvm_profiled_time / total_profiled_time
                           if total_profiled_time else 0.0)
                print("{} TVM compiling costs {:.2f}%".format(
                    get_benchmark_name(is_tvm, opt_level), tvm_pct))
Example #7
0
    def test_fall_back(self, shape):
        """Strict mode must raise on an op TVM cannot lower; non-strict
        mode falls back to the JIT and matches its output."""
        inputs = torch.rand(shape)

        def add(x):
            return torch.add(x, 1, 2)

        # Reference result from the plain JIT.
        jit_out = torch.jit.script(add)(inputs)

        # Under strict=True an unsupported op must raise.
        with self.assertRaises(RuntimeError):
            strict_scripted = torch.jit.script(add)
            torch_tvm.enable(strict=True)
            tvm_out = strict_scripted(inputs)
            torch_tvm.disable()

        # Under strict=False the op falls back and succeeds.
        torch_tvm.enable(strict=False)
        fallback_scripted = torch.jit.script(add)
        tvm_out = fallback_scripted(inputs)
        torch_tvm.disable()

        torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)
Example #8
0
    def runBoth(self, func, *inputs, check_tvm=True):
        """Trace ``func`` with plain JIT and with TVM lowering enabled;
        return ``(jit_output, tvm_output)`` for comparison by the caller.

        Args:
            func: callable to trace.
            *inputs: example tensors used for tracing and execution.
            check_tvm: when truthy, assert the optimized graph contains a
                ``tvm::CompilationGroup`` and that a "TVM" event appears
                in the profiler during execution.
        """
        # jit the function
        trace_jit = torch.jit.trace(func, inputs)
        ref_out = trace_jit(*inputs)

        # jit the function and lower to TVM
        torch_tvm.enable()
        trace_tvm = torch.jit.trace(func, inputs)
        tvm_out = trace_tvm(*inputs)

        if check_tvm:  # idiomatic truth test instead of `== True`
            tvm_unused = "TVM was not able to optimize this trace."
            assert "tvm::CompilationGroup" in str(
                trace_tvm.graph_for(*inputs)
            ), tvm_unused
            # tvm compile the graph and ensure TVM is used
            with profile() as p:
                _ = trace_tvm(*inputs)
            assert "TVM" in [evt.name for evt in p.function_events], tvm_unused

        torch_tvm.disable()

        return ref_out, tvm_out
Example #9
0
    def test_get_handle(self):
        """to_relay should accept a traced module, a plain Python
        function, and a scripted function alike."""
        size = 8
        x, y, z = torch.rand(size), torch.rand(size), torch.rand(size)

        def add(a, b, c):
            return a + b + c

        @torch.jit.script
        def mul(a, b, c):
            return a * b * c

        inputs = [x, y, z]

        torch_tvm.enable()

        traced_add = torch.jit.trace(add, inputs)

        # Each conversion should succeed regardless of the callable kind.
        relay_graph = torch_tvm.to_relay(traced_add, inputs)
        relay_graph = torch_tvm.to_relay(add, inputs)
        relay_graph = torch_tvm.to_relay(mul, inputs)

        torch_tvm.disable()
Example #10
0
    def test_core(self, shape):
        """Traced addition under TVM must match the plain JIT output for
        several torch_tvm.enable() configurations."""
        x = torch.rand(shape)
        y = torch.rand(shape)
        z = torch.rand(shape)

        def add(a, b, c):
            return a + b + c

        inputs = [x, y, z]

        trace_jit = torch.jit.trace(add, inputs)
        jit_out = trace_jit(*inputs)

        # Run the identical check under each configuration instead of
        # four copy-pasted stanzas (same behavior, no duplication).
        configs = (
            {},                                                  # defaults
            {"opt_level": 1},
            {"opt_level": 3},
            {"device_type": "cpu", "device": "llvm", "host": "llvm"},
        )
        for enable_kwargs in configs:
            torch_tvm.enable(**enable_kwargs)
            trace_tvm = torch.jit.trace(add, inputs)
            tvm_out = trace_tvm(*inputs)
            torch_tvm.disable()
            torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)