def test_fuse_linear_pattern_match(self, shape, out_features):
    input = torch.rand(shape)
    weight = torch.rand(out_features, shape[1])
    bias = torch.rand(out_features)

    def linear_addmm(input, weight, bias):
        return torch.addmm(bias, input, weight.t())

    def linear_matmul_add(input, weight, bias):
        output = input.matmul(weight.t())
        output += bias
        return output

    def linear_matmul(input, weight):
        return input.matmul(weight.t())

    import torch_tvm
    torch_tvm.enable()

    # test addmm
    scripted_addmm = torch.jit.script(linear_addmm)
    addmm_graph = scripted_addmm.graph_for(input, weight, bias)
    FileCheck().check("aten::linear").check_not("addmm").check_not("aten::t").run(str(addmm_graph))

    # test matmul + add
    scripted_matmul_add = torch.jit.script(linear_matmul_add)
    matmul_add_graph = scripted_matmul_add.graph_for(input, weight, bias)
    FileCheck().check("aten::linear").check_not("matmul").check_not("aten::t").run(str(matmul_add_graph))

    # test matmul
    scripted_matmul = torch.jit.script(linear_matmul)
    matmul_graph = scripted_matmul.graph_for(input, weight)
    FileCheck().check("aten::linear").check_not("matmul").check_not("aten::t").run(str(matmul_graph))

    torch_tvm.disable()
def test_dropout_removal(self, shape):
    input_a = torch.rand(shape)
    input_b = torch.rand(shape)
    input_c = torch.rand(shape)

    def dropout_training(a, b, c):
        t = a + b
        s = torch.dropout(t, 0.1, True)
        return s + c

    def dropout_inference(a, b, c):
        t = a + b
        s = torch.dropout(t, 0.1, False)
        return s + c

    torch_tvm.enable()
    tvm_graph_training = torch.jit.trace(dropout_training,
                                         (input_a, input_b, input_c))
    tvm_graph_inference = torch.jit.trace(dropout_inference,
                                          (input_a, input_b, input_c))
    torch_tvm.disable()

    assert "aten::dropout" in \
        str(tvm_graph_training.graph_for(input_a, input_b, input_c)), \
        "dropout must not be removed during training."
    assert "aten::dropout" not in \
        str(tvm_graph_inference.graph_for(input_a, input_b, input_c)), \
        "dropout must be removed during inference."
def runBoth(self, func, *inputs, check_tvm=True):
    with torch.no_grad():
        # jit the function
        trace_jit = torch.jit.trace(func, inputs)
        ref_out = trace_jit(*inputs)

        # jit the function and lower to TVM
        torch_tvm.enable()
        d = os.path.dirname(os.path.abspath(__file__))
        fn = os.path.join(d, "autotvm_tuning.log")
        with autotvm.apply_history_best(fn):
            trace_tvm = torch.jit.trace(func, inputs)
            try:
                tvm_out = trace_tvm(*inputs)
            except Exception as e:
                print("Error with graph\n{}".format(trace_tvm.graph))
                raise e
            if check_tvm:
                tvm_unused = "TVM was not able to optimize this trace."
                assert "tvm::CompilationGroup" in str(
                    trace_tvm.graph_for(*inputs)
                ), tvm_unused + " Graph:\n" + str(trace_tvm.graph_for(*inputs))
                # tvm compile the graph and ensure TVM is used
                with profile() as p:
                    _ = trace_tvm(*inputs)
                assert "TVM" in [evt.name for evt in p.function_events], tvm_unused
        torch_tvm.disable()
        return ref_out, tvm_out
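# A minimal usage sketch for runBoth (hypothetical call site, not part of the
# original file), following the assert_allclose convention used elsewhere in
# these tests: run a function under both plain JIT and TVM and check agreement.
def test_add_via_run_both(self):
    def add(a, b):
        return a + b
    ref_out, tvm_out = self.runBoth(add, torch.rand(4, 4), torch.rand(4, 4))
    torch.testing.assert_allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)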
def benchmark(model, csv_file, input_fn=genImage, iters=100, warmup=10):
    with torch.no_grad():
        torch_tvm.disable()
        inputs = input_fn()

        print("Tracing model with JIT")
        trace_jit = torch.jit.trace(model, inputs)
        print("Warming JIT up with {} runs".format(warmup))
        for _ in range(warmup):
            _ = trace_jit(*inputs)
        print("Running JIT {} times".format(iters))
        start = time.time()
        for _ in range(iters):
            _ = trace_jit(*inputs)
        jit_time = time.time() - start
        print("Done benchmarking JIT")

        d = os.path.dirname(os.path.abspath(__file__))
        fn = os.path.join(d, "autotvm_tuning.log")
        with autotvm.apply_history_best(fn):
            torch_tvm.enable(opt_level=3,
                             device_type="cpu",
                             device="llvm -mcpu=core-avx2",
                             host="llvm -mcpu=core-avx2")
            print("Tracing model with TVM")
            trace_tvm = torch.jit.trace(model, inputs)

            print("Warming TVM up with {} iters".format(warmup))
            for _ in range(warmup):
                _ = trace_tvm(*inputs)
            print("Running TVM {} times".format(iters))
            start = time.time()
            for _ in range(iters):
                _ = trace_tvm(*inputs)
            tvm_time = time.time() - start

            with torch.autograd.profiler.profile() as prof:
                _ = trace_tvm(*inputs)
            tvm_profiled_time = 0
            total_profiled_time = 0
            for p in prof.key_averages():
                total_profiled_time += int(p.cpu_time)
                if p.key == "TVM":
                    tvm_profiled_time += int(p.cpu_time)
            print("Done benchmarking TVM, which compiled {:.2f}% of compute".format(
                100 * tvm_profiled_time / total_profiled_time))

    if csv_file:
        exists = os.path.isfile(csv_file)
        with open(csv_file, 'a' if exists else 'w') as f:
            if not exists:
                f.write("timestamp,iter_per_sec\n")
            f.write("{},{}\n".format(int(time.time()), iters / tvm_time))

    print("JIT: {} iter/s\nTVM: {} iter/s".format(iters / jit_time, iters / tvm_time))
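# A hypothetical driver for benchmark(): the model choice and CSV path are
# illustrative, not from the original script. Assumes genImage() produces a
# tuple of inputs with the right shape for an image classification model.
if __name__ == "__main__":
    import torchvision
    model = torchvision.models.resnet50().eval()
    benchmark(model, "resnet50_results.csv")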
def checkTraceTVM(self, func, input_tensors=None, input_shapes=None,
                  size=100000, runs=100, verbose=False):
    # prepare inputs
    if input_tensors is None:
        if input_shapes is None:
            seed = torch.rand(size) / runs / 2
            input_tensors = (seed, seed, seed)
        else:
            input_tensors = []
            for shape in input_shapes:
                seed = torch.rand(*shape) / runs / 2
                input_tensors.append(seed)

    # jit the function
    trace_jit = torch.jit.trace(func, input_tensors)
    # specialize the graph with the inputs
    _ = trace_jit(*input_tensors)
    # time the JIT runs
    jit_start = time.time()
    for _ in range(runs):
        outputs_jit = trace_jit(*input_tensors)
    jit_time = time.time() - jit_start

    # jit the function and lower to TVM
    torch_tvm.enable()
    trace_tvm = torch.jit.trace(func, input_tensors)
    tvm_unused = "TVM was not able to optimize this trace."
    assert "tvm::CompilationGroup" in str(
        trace_tvm.graph_for(*input_tensors)), tvm_unused
    # tvm compile the graph and ensure TVM is used
    with profile() as p:
        _ = trace_tvm(*input_tensors)
    assert "TVM" in [evt.name for evt in p.function_events], tvm_unused
    torch_tvm.disable()

    # time the TVM runs
    tvm_start = time.time()
    for _ in range(runs):
        outputs_tvm = trace_tvm(*input_tensors)
    tvm_time = time.time() - tvm_start

    if verbose:
        print("\noperator " + func.__name__
              + ":\t{} runs of size {}".format(runs, size)
              + " \tjit time:{:.4f}s".format(jit_time)
              + "\ttvm time:{:.4f}s".format(tvm_time))
    self.assertEqual(outputs_jit, outputs_tvm)
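# A sketch of a call site for checkTraceTVM (hypothetical): with the default
# arguments it fabricates three identical input tensors, so the function under
# test should take three arguments.
def test_add3_perf(self):
    def add3(a, b, c):
        return a + b + c
    self.checkTraceTVM(add3, size=10000, runs=50, verbose=True)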
def benchmark_pytorch_lstm(is_tvm=False, opt_level=0):
    torch_tvm.disable()
    # Ensure your model is in eval mode and also turn off gradients.
    with torch.no_grad():
        inputs = model_inputs()
        model = get_pytorch_model()
        if is_tvm:
            torch_tvm.enable(opt_level=opt_level,
                             device_type="cpu",
                             device="llvm -mcpu=core-avx2",
                             host="llvm -mcpu=core-avx2")
            # torch_tvm.enable(opt_level=opt_level)
        # This is where all the compilation happens.
        mod = torch.jit.trace(model, inputs)
        dry_run = 10  # use 10 iterations to warm up
        run = 100
        for i in range(dry_run + run):
            if i == dry_run:
                tic = time.time()
            _ = mod(inputs)
        time_iter = (time.time() - tic) * 1000 / run
        print(f"{get_benchmark_name(is_tvm, opt_level)}, timing: {time_iter} ms")
        if is_tvm:
            if "tvm::CompilationGroup" not in str(mod.graph_for(inputs)):
                print("TVM was not able to optimize this trace.")
            else:
                with torch.autograd.profiler.profile() as prof:
                    _ = mod(inputs)
                tvm_profiled_time = 0
                total_profiled_time = 0
                for p in prof.key_averages():
                    total_profiled_time += int(p.cpu_time)
                    if p.key == "TVM":
                        tvm_profiled_time += int(p.cpu_time)
                print("{}: TVM compiled {:.2f}% of compute".format(
                    get_benchmark_name(is_tvm, opt_level),
                    100 * tvm_profiled_time / total_profiled_time))
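# A sketch of a driver sweeping the optimization levels exercised in these
# benchmarks; the baseline run plus the 0-3 range mirrors the opt_level values
# used elsewhere in this file, but the loop itself is illustrative.
if __name__ == "__main__":
    benchmark_pytorch_lstm(is_tvm=False)
    for level in range(4):
        benchmark_pytorch_lstm(is_tvm=True, opt_level=level)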
def test_fall_back(self, shape):
    inputs = torch.rand(shape)

    def add(input):
        return torch.add(input, 1, 2)

    jit_script_reshape = torch.jit.script(add)
    jit_out = jit_script_reshape(inputs)

    # In strict mode, running a graph with an op that cannot be lowered to
    # TVM raises instead of falling back to the JIT interpreter.
    with self.assertRaises(RuntimeError):
        tvm_strict_script_reshape = torch.jit.script(add)
        torch_tvm.enable(strict=True)
        tvm_out = tvm_strict_script_reshape(inputs)
    torch_tvm.disable()

    # With strict=False the unsupported op silently falls back.
    torch_tvm.enable(strict=False)
    tvm_script_reshape = torch.jit.script(add)
    tvm_out = tvm_script_reshape(inputs)
    torch_tvm.disable()

    torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)
def runBoth(self, func, *inputs, check_tvm=True):
    # jit the function
    trace_jit = torch.jit.trace(func, inputs)
    ref_out = trace_jit(*inputs)

    # jit the function and lower to TVM
    torch_tvm.enable()
    trace_tvm = torch.jit.trace(func, inputs)
    tvm_out = trace_tvm(*inputs)

    if check_tvm:
        tvm_unused = "TVM was not able to optimize this trace."
        assert "tvm::CompilationGroup" in str(
            trace_tvm.graph_for(*inputs)
        ), tvm_unused
        # tvm compile the graph and ensure TVM is used
        with profile() as p:
            _ = trace_tvm(*inputs)
        assert "TVM" in [evt.name for evt in p.function_events], tvm_unused

    torch_tvm.disable()
    return ref_out, tvm_out
def test_get_handle(self):
    shape = 8
    x = torch.rand(shape)
    y = torch.rand(shape)
    z = torch.rand(shape)

    def add(a, b, c):
        return a + b + c

    @torch.jit.script
    def mul(a, b, c):
        return a * b * c

    inputs = [x, y, z]
    torch_tvm.enable()
    # to_relay accepts a traced module, a plain Python function,
    # or a scripted function.
    trace_tvm = torch.jit.trace(add, inputs)
    relay_graph = torch_tvm.to_relay(trace_tvm, inputs)
    relay_graph = torch_tvm.to_relay(add, inputs)
    relay_graph = torch_tvm.to_relay(mul, inputs)
    torch_tvm.disable()
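# A minimal follow-on sketch (hypothetical helper, not part of the test):
# assumes the handle returned by to_relay prints as Relay IR text, which is
# handy for checking which ops were captured.
def dump_relay(fn, example_inputs):
    torch_tvm.enable()
    try:
        print(torch_tvm.to_relay(fn, example_inputs))
    finally:
        torch_tvm.disable()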
def test_core(self, shape):
    x = torch.rand(shape)
    y = torch.rand(shape)
    z = torch.rand(shape)

    def add(a, b, c):
        return a + b + c

    inputs = [x, y, z]
    trace_jit = torch.jit.trace(add, inputs)
    jit_out = trace_jit(*inputs)

    # default options
    torch_tvm.enable()
    trace_tvm = torch.jit.trace(add, inputs)
    tvm_out = trace_tvm(*inputs)
    torch_tvm.disable()
    torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)

    # explicit opt levels
    torch_tvm.enable(opt_level=1)
    trace_tvm = torch.jit.trace(add, inputs)
    tvm_out = trace_tvm(*inputs)
    torch_tvm.disable()
    torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)

    torch_tvm.enable(opt_level=3)
    trace_tvm = torch.jit.trace(add, inputs)
    tvm_out = trace_tvm(*inputs)
    torch_tvm.disable()
    torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)

    # explicit device/host targets
    torch_tvm.enable(device_type="cpu", device="llvm", host="llvm")
    trace_tvm = torch.jit.trace(add, inputs)
    tvm_out = trace_tvm(*inputs)
    torch_tvm.disable()
    torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)