def test_fuse_linear_pattern_match(self, shape, out_features):
    """Check that linear-like patterns are rewritten to ``aten::linear``.

    Three formulations are scripted with TVM enabled and their optimized
    graphs are inspected with ``FileCheck``:

    * ``torch.addmm(bias, input, weight.t())``
    * ``input.matmul(weight.t())`` followed by an in-place bias add
    * ``input.matmul(weight.t())`` alone

    Args:
        shape: Shape passed to ``torch.rand`` for the input tensor;
            ``shape[1]`` is also used as the weight's second dimension.
        out_features: First dimension of the weight and length of the bias.
    """
    data = torch.rand(shape)
    weight = torch.rand(out_features, shape[1])
    bias = torch.rand(out_features)

    def linear_addmm(input, weight, bias):
        return torch.addmm(bias, input, weight.t())

    def linear_matmul_add(input, weight, bias):
        output = input.matmul(weight.t())
        output += bias
        return output

    def linear_matmul(input, weight):
        return input.matmul(weight.t())

    import torch_tvm

    torch_tvm.enable()
    try:
        # test addmm
        # The addmm variant must be scripted here; scripting the
        # matmul+add variant would make the "addmm" check vacuous.
        scripted_addmm = torch.jit.script(linear_addmm)
        addmm_graph = scripted_addmm.graph_for(data, weight, bias)
        FileCheck().check("aten::linear").check_not("addmm").check_not(
            "aten::t"
        ).run(str(addmm_graph))

        # test matmul + add
        scripted_matmul_add = torch.jit.script(linear_matmul_add)
        matmul_add_graph = scripted_matmul_add.graph_for(data, weight, bias)
        FileCheck().check("aten::linear").check_not("matmul").check_not(
            "aten::t"
        ).run(str(matmul_add_graph))

        # test matmul
        scripted_matmul = torch.jit.script(linear_matmul)
        matmul_graph = scripted_matmul.graph_for(data, weight)
        FileCheck().check("aten::linear").check_not("matmul").check_not(
            "aten::t"
        ).run(str(matmul_graph))
    finally:
        # Always switch TVM back off so a failed check does not leak the
        # enabled state into other tests.
        torch_tvm.disable()
def __init__(
    self,
    eager_encoder: TransformerSentenceEncoderModule,
    tokens: torch.Tensor,
    segment_labels: torch.Tensor = None,
    positions: torch.Tensor = None,
) -> None:
    """Trace ``eager_encoder`` on example inputs and try to enable TVM.

    The encoder is wrapped in ``TraceableTransformerWrapper`` and traced
    with the inputs produced by ``self._prepare_inputs``. When CUDA is
    available, ``torch_tvm`` is imported and enabled for GPU training; a
    missing ``torch_tvm`` package is reported and otherwise ignored.

    Args:
        eager_encoder: Encoder module to wrap and trace.
        tokens: Example token tensor used for tracing.
        segment_labels: Optional example segment labels.
        positions: Optional example positions.
    """
    super().__init__()
    wrapped_encoder = TraceableTransformerWrapper(eager_encoder)
    example_inputs = self._prepare_inputs(tokens, segment_labels, positions)

    # Remember which optional inputs were supplied at trace time.
    self.has_segment_labels = segment_labels is not None
    self.has_positions = positions is not None
    self.iter_ = 0

    # do not check trace because of non-deterministic ops (e.g. dropout)
    self.traced_encoder = torch.jit.trace(
        wrapped_encoder,
        tuple(example_inputs),
        check_trace=False,
    )

    if torch.cuda.is_available():
        try:
            import torch_tvm

            tvm_options = {
                "device_type": "gpu",
                "device": "cuda",
                "device_id": torch.cuda.current_device(),
                "is_training": True,
            }
            torch_tvm.enable(**tvm_options)
            print("Using TVM in traced transformer")
        except ImportError:
            print("Not using TVM in traced transformer")

    log_class_usage(__class__)
def runBoth(self, func, *inputs, check_tvm=True):
    """Trace ``func`` with plain JIT and with TVM enabled; return both outputs.

    The TVM trace is created and first run under
    ``autotvm.apply_history_best`` using the ``autotvm_tuning.log`` file
    located next to this source file.

    Args:
        func: Callable to trace.
        *inputs: Example inputs used both for tracing and for running.
        check_tvm: When truthy, assert that the optimized graph contains a
            ``tvm::CompilationGroup`` node and that a profiled run records
            an event named ``"TVM"``.

    Returns:
        A ``(ref_out, tvm_out)`` tuple: the plain-JIT output and the output
        of the trace created with TVM enabled.

    Raises:
        AssertionError: If ``check_tvm`` is truthy and TVM was not used.
        Exception: Any error from running the TVM trace is re-raised after
            the trace's graph has been printed.
    """
    with torch.no_grad():
        # jit the function
        trace_jit = torch.jit.trace(func, inputs)
        ref_out = trace_jit(*inputs)

        # jit the function and lower to TVM
        torch_tvm.enable()
        try:
            log_dir = os.path.dirname(os.path.abspath(__file__))
            tuning_log = os.path.join(log_dir, "autotvm_tuning.log")

            with autotvm.apply_history_best(tuning_log):
                trace_tvm = torch.jit.trace(func, inputs)
                try:
                    tvm_out = trace_tvm(*inputs)
                except Exception:
                    print("Error with graph\n{}".format(trace_tvm.graph))
                    # Bare raise keeps the original traceback intact.
                    raise

            if check_tvm:
                tvm_unused = "TVM was not able to optimize this trace."
                graph_text = str(trace_tvm.graph_for(*inputs))
                assert "tvm::CompilationGroup" in graph_text, (
                    tvm_unused + " Graph:\n" + graph_text
                )
                # tvm compile the graph and ensure TVM is used
                with profile() as p:
                    trace_tvm(*inputs)
                event_names = [event.name for event in p.function_events]
                assert "TVM" in event_names, tvm_unused
        finally:
            # Restore the global TVM switch even when tracing, running or
            # an assertion fails, so later tests start from a clean state.
            torch_tvm.disable()

        return ref_out, tvm_out
def test_dropout_removal(self, shape):
    """Check that dropout survives in training graphs and is dropped otherwise.

    Two functions that differ only in the ``train`` flag passed to
    ``torch.dropout`` are traced while TVM is enabled; their optimized
    graphs are then searched for ``aten::dropout``.

    Args:
        shape: Shape passed to ``torch.rand`` for each of the three inputs.
    """
    example_inputs = tuple(torch.rand(shape) for _ in range(3))

    def dropout_training(a, b, c):
        return torch.dropout(a + b, 0.1, True) + c

    def dropout_inference(a, b, c):
        return torch.dropout(a + b, 0.1, False) + c

    torch_tvm.enable()
    tvm_graph_training = torch.jit.trace(dropout_training, example_inputs)
    tvm_graph_inference = torch.jit.trace(dropout_inference, example_inputs)
    torch_tvm.disable()

    training_graph_text = str(tvm_graph_training.graph_for(*example_inputs))
    assert (
        "aten::dropout" in training_graph_text
    ), "dropout must not be removed during training."

    inference_graph_text = str(tvm_graph_inference.graph_for(*example_inputs))
    assert (
        "aten::dropout" not in inference_graph_text
    ), "dropout must be removed during inference."
def main():
    """Load a saved model and convert its ``rnn`` submodule to a Relay graph.

    The model file path is read from ``argv[1]``. One batch of baskets is
    built from the ``'prior'`` data, embedded with the model, and passed to
    ``torch_tvm.to_relay`` together with an initial hidden state as the
    sample input. The conversion result is not used further.
    """
    torch_tvm.enable(opt_level=3, device_type="gpu", device="cuda", host="llvm")

    model_path = argv[1]
    dr_config = Config(constants.DREAM_CONFIG)
    with open(model_path, 'rb') as model_file:
        dr_model = torch.load(model_file)

    constructor = BasketConstructor(constants.RAW_DATA_DIR, constants.FEAT_DATA_DIR)
    ub_basket = constructor.get_baskets('prior', reconstruct=False)

    # this function needs a sample input to infer types
    batch = Dataset(ub_basket)[0:dr_config.batch_size]
    baskets, lens, users = batch
    baskets, lens, users = sort_batch_of_lists(baskets, lens, users)
    baskets = pad_batch_of_lists(baskets, lens[0])
    dr_hidden = dr_model.init_hidden(dr_config.batch_size)

    # users' basket sequence: one concatenated embedding per user, each
    # given a leading dimension before being joined along it
    per_user = [
        torch.cat(dr_model.embed_baskets(user_baskets), 0).unsqueeze(0)
        for user_baskets in baskets
    ]
    ub_seqs = torch.cat(per_user, 0)

    sample_args = [ub_seqs, dr_hidden]
    torch_tvm.to_relay(dr_model.rnn, sample_args)
def benchmark(model, input_fn=genImage, iters=100, warmup=10):
    """Time ``model`` traced with plain JIT and again with TVM enabled.

    Args:
        model: Callable/module to trace.
        input_fn: Zero-argument callable returning the example inputs; the
            result is unpacked with ``*`` when the traces are called.
        iters: Number of timed calls per backend.
        warmup: Number of untimed calls made before timing each backend.
    """

    def repeat(module, args, count):
        # Call the traced module `count` times, discarding the outputs.
        for _ in range(count):
            module(*args)

    with torch.no_grad():
        example = input_fn()

        print("Tracing model with JIT")
        jit_module = torch.jit.trace(model, example)
        print("Warming JIT up with {} runs".format(warmup))
        repeat(jit_module, example, warmup)
        print("Running JIT {} times".format(iters))
        began = time.time()
        repeat(jit_module, example, iters)
        jit_time = time.time() - began
        print("Done benchmarking JIT")

        with autotvm.apply_history_best("test/autotvm_tuning.log"):
            torch_tvm.enable()
            print("Tracing model with TVM")
            tvm_module = torch.jit.trace(model, example)
            print("Warming TVM up with {} iters".format(warmup))
            repeat(tvm_module, example, warmup)
            print("Running TVM {} times".format(iters))
            began = time.time()
            repeat(tvm_module, example, iters)
            tvm_time = time.time() - began
            print("Done benchmarking TVM")

        jit_rate = iters / jit_time
        tvm_rate = iters / tvm_time
        print("JIT: {} iter/s\nTVM: {} iter/s".format(jit_rate, tvm_rate))
def benchmark(model, csv_file, input_fn=genImage, iters=100, warmup=10):
    """Time ``model`` under plain JIT and under TVM, and report TVM coverage.

    After timing, one extra TVM run is profiled and the share of profiled
    CPU time attributed to the ``"TVM"`` key is printed. When ``csv_file``
    is truthy, a ``timestamp,iter_per_sec`` row with the TVM throughput is
    appended to it (a header row is written first if the file is new).

    Args:
        model: Callable/module to trace.
        csv_file: Path of the CSV file to append to, or a falsy value to
            skip writing.
        input_fn: Zero-argument callable returning the example inputs; the
            result is unpacked with ``*`` when the traces are called.
        iters: Number of timed calls per backend.
        warmup: Number of untimed calls made before timing each backend.
    """

    def repeat(module, args, count):
        # Call the traced module `count` times, discarding the outputs.
        for _ in range(count):
            module(*args)

    with torch.no_grad():
        torch_tvm.disable()
        example = input_fn()

        print("Tracing model with JIT")
        jit_module = torch.jit.trace(model, example)
        print("Warming JIT up with {} runs".format(warmup))
        repeat(jit_module, example, warmup)
        print("Running JIT {} times".format(iters))
        began = time.time()
        repeat(jit_module, example, iters)
        jit_time = time.time() - began
        print("Done benchmarking JIT")

        here = os.path.dirname(os.path.abspath(__file__))
        tuning_log = os.path.join(here, "autotvm_tuning.log")
        with autotvm.apply_history_best(tuning_log):
            torch_tvm.enable(
                opt_level=3,
                device_type="cpu",
                device="llvm -mcpu=core-avx2",
                host="llvm -mcpu=core-avx2",
            )
            print("Tracing model with TVM")
            tvm_module = torch.jit.trace(model, example)
            print("Warming TVM up with {} iters".format(warmup))
            repeat(tvm_module, example, warmup)
            print("Running TVM {} times".format(iters))
            began = time.time()
            repeat(tvm_module, example, iters)
            tvm_time = time.time() - began

            with torch.autograd.profiler.profile() as prof:
                tvm_module(*example)

            # Sum the integer-truncated CPU time of every profiled key, and
            # separately of the key named "TVM".
            averaged = list(prof.key_averages())
            total_profiled_time = sum(int(entry.cpu_time) for entry in averaged)
            tvm_profiled_time = sum(
                int(entry.cpu_time) for entry in averaged if entry.key == "TVM"
            )
            print("Done benchmarking TVM, which compiled {:.2f}% of compute".
                  format(100 * tvm_profiled_time / total_profiled_time))

        if csv_file:
            exists = os.path.isfile(csv_file)
            mode = 'a' if exists else 'w'
            with open(csv_file, mode) as out:
                if not exists:
                    out.write("timestamp,iter_per_sec\n")
                out.write("{},{}\n".format(int(time.time()), iters / tvm_time))

        jit_rate = iters / jit_time
        tvm_rate = iters / tvm_time
        print("JIT: {} iter/s\nTVM: {} iter/s".format(jit_rate, tvm_rate))
def test_concat_fuse(self, shape):
    """Check that ``torch.cat`` over a tuple is fused into ``prim::FusedConcat``.

    The scripted function's optimized graph must contain
    ``prim::FusedConcat`` and, after it, neither ``prim::ListConstruct`` nor
    ``aten::cat``.

    Args:
        shape: Shape passed to ``torch.rand`` for both inputs.
    """
    input1 = torch.rand(shape)
    input2 = torch.rand(shape)

    def concat(x1, x2):
        return torch.cat((x1, x2), 0)

    import torch_tvm

    torch_tvm.enable()
    try:
        # test concat
        scripted_concat = torch.jit.script(concat)
        concat_graph = scripted_concat.graph_for(input1, input2)
        FileCheck().check("prim::FusedConcat").check_not(
            "prim::ListConstruct"
        ).check_not("aten::cat").run(str(concat_graph))
    finally:
        # Switch TVM back off (pass or fail) so the enabled state does not
        # leak into whichever test runs next.
        torch_tvm.disable()
def checkTraceTVM(self, func, input_tensors=None, input_shapes=None, size=100000, runs=100, verbose=False):
    """Trace ``func`` with and without TVM, time both, and compare outputs.

    Args:
        func: Callable to trace.
        input_tensors: Inputs to use. When ``None`` they are generated from
            ``input_shapes`` or, failing that, from ``size``.
        input_shapes: Iterable of shapes; one random tensor is generated
            per shape when ``input_tensors`` is ``None``.
        size: Length of the single random tensor that is passed three
            times when neither ``input_tensors`` nor ``input_shapes`` is
            given.
        runs: Number of timed calls per backend; generated inputs are also
            divided by ``runs`` and then by 2.
        verbose: When true, print the measured JIT and TVM times.

    Raises:
        AssertionError: If TVM did not take over the trace, or (via
            ``self.assertEqual``) if the two outputs differ.
    """
    # prepare inputs
    if input_tensors is None and input_shapes is None:
        seed = torch.rand(size) / runs / 2
        input_tensors = (seed,) * 3
    elif input_tensors is None:
        input_tensors = [torch.rand(*shape) / runs / 2 for shape in input_shapes]

    # jit the function
    trace_jit = torch.jit.trace(func, input_tensors)
    # specialize the graph with the inputs
    trace_jit(*input_tensors)
    # timeit the perf
    jit_start = time.time()
    for _ in range(runs):
        outputs_jit = trace_jit(*input_tensors)
    jit_time = time.time() - jit_start

    # jit the function and lower to TVM
    torch_tvm.enable()
    trace_tvm = torch.jit.trace(func, input_tensors)
    tvm_unused = "TVM was not able to optimize this trace."
    optimized_graph = str(trace_tvm.graph_for(*input_tensors))
    assert "tvm::CompilationGroup" in optimized_graph, tvm_unused
    # tvm compile the graph and ensure TVM is used
    with profile() as p:
        trace_tvm(*input_tensors)
    recorded_names = [event.name for event in p.function_events]
    assert "TVM" in recorded_names, tvm_unused
    torch_tvm.disable()

    # timeit the perf
    tvm_start = time.time()
    for _ in range(runs):
        outputs_tvm = trace_tvm(*input_tensors)
    tvm_time = time.time() - tvm_start

    if verbose:
        report_parts = (
            "\noperator ",
            func.__name__,
            ":\t{} runs of size {}".format(runs, size),
            " \tjit time:{:.4f}s".format(jit_time),
            "\ttvm time:{:.4f}s".format(tvm_time),
        )
        print("".join(report_parts))

    self.assertEqual(outputs_jit, outputs_tvm)
def benmarch_pytorch_lstm(is_tvm=False, opt_level=0):
    """Benchmark the traced PyTorch model, optionally with TVM enabled.

    The model is traced, called 10 times as warm-up and then 100 timed
    times; the mean per-call time in milliseconds is printed. With
    ``is_tvm`` set, the function additionally reports either that TVM did
    not optimize the trace or the share of profiled CPU time attributed to
    the ``"TVM"`` key.

    Args:
        is_tvm: Whether to enable TVM before tracing.
        opt_level: Optimization level forwarded to ``torch_tvm.enable`` and
            to ``get_benchmark_name``.
    """
    torch_tvm.disable()
    # Ensure your model is in eval mode and also turn off gradients.
    with torch.no_grad():
        inputs = model_inputs()
        model = get_pytorch_model()
        if is_tvm:
            torch_tvm.enable(
                opt_level=opt_level,
                device_type="cpu",
                device="llvm -mcpu=core-avx2",
                host="llvm -mcpu=core-avx2",
            )
        # This is where all the compilation happens.
        mod = torch.jit.trace(model, inputs)

        dry_run = 10  # use 10 iterations to warm up
        run = 100
        for _ in range(dry_run):
            mod(inputs)
        tic = time.time()
        for _ in range(run):
            mod(inputs)
        time_iter = (time.time() - tic) * 1000 / run
        print(
            f"{get_benchmark_name(is_tvm, opt_level)}, timing: {time_iter} ms")

        if not is_tvm:
            return

        if "tvm::CompilationGroup" in str(mod.graph_for(inputs)):
            with torch.autograd.profiler.profile() as prof:
                mod(inputs)

            # Sum the integer-truncated CPU time of every profiled key, and
            # separately of the key named "TVM".
            averaged = list(prof.key_averages())
            total_profiled_time = sum(int(entry.cpu_time) for entry in averaged)
            tvm_profiled_time = sum(
                int(entry.cpu_time) for entry in averaged if entry.key == "TVM"
            )
            print("{} TVM compiling costs {:.2f}%".format(
                get_benchmark_name(is_tvm, opt_level),
                100 * tvm_profiled_time / total_profiled_time))
        else:
            print("TVM was not able to optimize this trace.")
def test_fall_back(self, shape):
    """Check strict-mode failure and non-strict fallback for the same op.

    The scripted function calls ``torch.add(input, 1, 2)``. With
    ``torch_tvm.enable(strict=True)`` running it is expected to raise
    ``RuntimeError``; with ``strict=False`` it is expected to run and match
    the plain-JIT result within the given tolerances.

    Args:
        shape: Shape passed to ``torch.rand`` for the input tensor.
    """
    inputs = torch.rand(shape)

    def add(input):
        return torch.add(input, 1, 2)

    jit_script_add = torch.jit.script(add)
    jit_out = jit_script_add(inputs)

    tvm_strict_script_add = torch.jit.script(add)
    torch_tvm.enable(strict=True)
    try:
        # Only the call itself is expected to raise; keeping scripting and
        # enable() outside the assertRaises scope avoids masking an
        # unrelated RuntimeError from setup.
        with self.assertRaises(RuntimeError):
            tvm_strict_script_add(inputs)
    finally:
        # A disable() placed after the raising call inside the assertRaises
        # block would never execute, so clean up here instead.
        torch_tvm.disable()

    torch_tvm.enable(strict=False)
    try:
        tvm_script_add = torch.jit.script(add)
        tvm_out = tvm_script_add(inputs)
    finally:
        torch_tvm.disable()

    torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)
def benchmark(model, csv_file, input_fn=genImage, iters=100, warmup=10):
    """Time ``model`` under plain JIT and under TVM; optionally log to CSV.

    When ``csv_file`` is truthy, a ``timestamp,iter_per_sec`` row with the
    TVM throughput is appended to it (a header row is written first if the
    file is new).

    Args:
        model: Callable/module to trace.
        csv_file: Path of the CSV file to append to, or a falsy value to
            skip writing.
        input_fn: Zero-argument callable returning the example inputs; the
            result is unpacked with ``*`` when the traces are called.
        iters: Number of timed calls per backend.
        warmup: Number of untimed calls made before timing each backend.
    """

    def repeat(module, args, count):
        # Call the traced module `count` times, discarding the outputs.
        for _ in range(count):
            module(*args)

    with torch.no_grad():
        example = input_fn()

        print("Tracing model with JIT")
        jit_module = torch.jit.trace(model, example)
        print("Warming JIT up with {} runs".format(warmup))
        repeat(jit_module, example, warmup)
        print("Running JIT {} times".format(iters))
        began = time.time()
        repeat(jit_module, example, iters)
        jit_time = time.time() - began
        print("Done benchmarking JIT")

        with autotvm.apply_history_best("test/autotvm_tuning.log"):
            torch_tvm.enable()
            print("Tracing model with TVM")
            tvm_module = torch.jit.trace(model, example)
            print("Warming TVM up with {} iters".format(warmup))
            repeat(tvm_module, example, warmup)
            print("Running TVM {} times".format(iters))
            began = time.time()
            repeat(tvm_module, example, iters)
            tvm_time = time.time() - began
            print("Done benchmarking TVM")

        if csv_file:
            exists = os.path.isfile(csv_file)
            mode = 'a' if exists else 'w'
            with open(csv_file, mode) as out:
                if not exists:
                    out.write("timestamp,iter_per_sec\n")
                out.write("{},{}\n".format(int(time.time()), iters / tvm_time))

        jit_rate = iters / jit_time
        tvm_rate = iters / tvm_time
        print("JIT: {} iter/s\nTVM: {} iter/s".format(jit_rate, tvm_rate))
def runBoth(self, func, *inputs, check_tvm=True):
    """Trace ``func`` with plain JIT and with TVM enabled; return both outputs.

    Args:
        func: Callable to trace.
        *inputs: Example inputs used both for tracing and for running.
        check_tvm: When truthy, assert that the optimized graph contains a
            ``tvm::CompilationGroup`` node and that a profiled run records
            an event named ``"TVM"``.

    Returns:
        A ``(ref_out, tvm_out)`` tuple: the plain-JIT output and the output
        of the trace created with TVM enabled.

    Raises:
        AssertionError: If ``check_tvm`` is truthy and TVM was not used.
    """
    # jit the function
    trace_jit = torch.jit.trace(func, inputs)
    ref_out = trace_jit(*inputs)

    # jit the function and lower to TVM
    torch_tvm.enable()
    try:
        trace_tvm = torch.jit.trace(func, inputs)
        tvm_out = trace_tvm(*inputs)

        if check_tvm:
            tvm_unused = "TVM was not able to optimize this trace."
            graph_text = str(trace_tvm.graph_for(*inputs))
            assert "tvm::CompilationGroup" in graph_text, tvm_unused
            # tvm compile the graph and ensure TVM is used
            with profile() as p:
                trace_tvm(*inputs)
            event_names = [event.name for event in p.function_events]
            assert "TVM" in event_names, tvm_unused
    finally:
        # Restore the global TVM switch even when tracing, running or an
        # assertion fails, so later tests start from a clean state.
        torch_tvm.disable()

    return ref_out, tvm_out
def test_get_handle(self):
    """Call ``torch_tvm.to_relay`` on a traced, a plain and a scripted function.

    The test only checks that the three conversions complete; the returned
    values are not inspected.
    """
    shape = 8
    inputs = [torch.rand(shape) for _ in range(3)]

    def add(a, b, c):
        return a + b + c

    @torch.jit.script
    def mul(a, b, c):
        return a * b * c

    torch_tvm.enable()
    traced_add = torch.jit.trace(add, inputs)
    # Same conversion for each kind of callable, in the original order:
    # traced module, plain Python function, scripted function.
    for target in (traced_add, add, mul):
        torch_tvm.to_relay(target, inputs)
    torch_tvm.disable()
def test_core(self, shape):
    """Compare a three-way add under plain JIT and several TVM configurations.

    For each ``torch_tvm.enable`` configuration the function is re-traced,
    run, TVM is disabled again, and the result is compared with the
    plain-JIT output using ``rtol=0.01`` and ``atol=0.01``.

    Args:
        shape: Shape passed to ``torch.rand`` for each of the three inputs.
    """
    inputs = [torch.rand(shape) for _ in range(3)]

    def add(a, b, c):
        return a + b + c

    jit_out = torch.jit.trace(add, inputs)(*inputs)

    # Keyword arguments for each enable() call, in the order exercised.
    enable_variants = (
        {},
        {"opt_level": 1},
        {"opt_level": 3},
        {"device_type": "cpu", "device": "llvm", "host": "llvm"},
    )
    for options in enable_variants:
        torch_tvm.enable(**options)
        tvm_module = torch.jit.trace(add, inputs)
        tvm_out = tvm_module(*inputs)
        torch_tvm.disable()
        torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01)
def main():
    """Train the DREAM model and checkpoint it whenever validation improves.

    Steps, in order: set the CUDA environment variables, enable TVM for
    GPU, build the train/test datasets (with reordered baskets and item
    history when ``constants.REORDER`` is set), create the model, optimizer
    and tensorboard writer, then run the epoch loop. After every epoch the
    model is saved if the validation loss is the best seen so far. A
    ``KeyboardInterrupt`` stops training early with a message.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = constants.GPUS
    torch_tvm.enable(opt_level=3, device_type="gpu", device="cuda", host="llvm")

    # Prepare input
    bc = BasketConstructor(constants.RAW_DATA_DIR, constants.FEAT_DATA_DIR)
    # Users' baskets
    ub_basket = bc.get_baskets('prior', reconstruct=False)

    if constants.REORDER:
        # Users' reordered baskets
        ub_rbks = bc.get_baskets('prior', reconstruct=False, reordered=True)
        # User's item history
        ub_ihis = bc.get_item_history('prior', reconstruct=False)
        # Train test split
        train_ub, test_ub, train_rbks, test_rbks, train_ihis, test_ihis = train_test_split(
            ub_basket, ub_rbks, ub_ihis, test_size=0.2)
        del ub_basket, ub_rbks, ub_ihis  # memory saving
        train_ub, test_ub = Dataset(train_ub, train_rbks, train_ihis), Dataset(
            test_ub, test_rbks, test_ihis)
        del train_rbks, test_rbks, train_ihis, test_ihis  # memory saving
    else:
        train_ub, test_ub = train_test_split(ub_basket, test_size=0.2)
        del ub_basket
        train_ub, test_ub = Dataset(train_ub), Dataset(test_ub)

    # Model config
    dr_config = Config(constants.DREAM_CONFIG)
    dr_model = DreamModel(dr_config)
    if dr_config.cuda:
        dr_model.cuda()

    # Optimizer
    optim = torch.optim.Adam(dr_model.parameters(), lr=dr_config.learning_rate)
    # optim = torch.optim.Adadelta(dr_model.parameters())
    # optim = torch.optim.SGD(dr_model.parameters(), lr=dr_config.learning_rate, momentum=0.9)

    writer = SummaryWriter(log_dir='runs/{}'.format(
        dr_config.alias))  # tensorboard writer
    writer.add_text('config', str(dr_config))

    best_val_loss = None
    try:
        for k, v in constants.DREAM_CONFIG.items():
            print(k, v)

        # training
        # NOTE(review): the train_*/evaluate_* helpers are called without
        # arguments, so they presumably read the model, data and optimizer
        # from module-level state rather than from the locals built above
        # -- confirm against their definitions.
        for epoch in range(dr_config.epochs):
            if constants.REORDER:
                train_reorder_dream()
            else:
                train_dream()
            print('-' * 89)
            if constants.REORDER:
                val_loss = evaluate_reorder_dream()
            else:
                val_loss = evaluate_dream()
            print('-' * 89)

            # checkpoint
            # Compare against None explicitly: a truthiness test would
            # treat a best loss of exactly 0 as "no best yet" and keep
            # overwriting the checkpoint with worse models.
            if best_val_loss is None or val_loss < best_val_loss:
                checkpoint_path = dr_config.checkpoint_dir.format(
                    epoch=epoch, loss=val_loss)
                with open(checkpoint_path, 'wb') as f:
                    torch.save(dr_model, f)
                best_val_loss = val_loss
            else:
                # Manual SGD slow down lr if no improvement in val_loss
                # dr_config.learning_rate = dr_config.learning_rate / 4
                pass
    except KeyboardInterrupt:
        print('*' * 89)
        print('Got keyboard Interrupt and stopped early')