def best_pytorch_func(data_dict):
    func_dict = {"fw": [], "bw_inp_g": [], "bw_wei_g": []}
    for key, val in data_dict.items():
        func_fw, func_inp_g, func_wei_g = best_functions(val.values())
        func_dict['fw'].append(to_pytorch_func(func_fw))
        func_dict['bw_inp_g'].append(to_pytorch_func(func_inp_g))
        func_dict['bw_wei_g'].append(to_pytorch_func(func_wei_g))
    return func_dict
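# A hedged sketch (not from the original repo) of how the fw / bw_inp_g / bw_wei_g
# kernels collected above are typically consumed: a torch.autograd.Function that
# calls the TVM forward kernel in forward() and the two gradient kernels in
# backward(). The argument order of the gradient kernels is an assumption.
import torch

class TvmOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, fw, bw_inp_g, bw_wei_g, out_shape):
        out = inp.new_empty(out_shape)
        fw(inp, weight, out)                  # TVM kernel writes into `out` in place
        ctx.save_for_backward(inp, weight)
        ctx.fns = (bw_inp_g, bw_wei_g)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        bw_inp_g, bw_wei_g = ctx.fns
        grad_inp, grad_wei = torch.empty_like(inp), torch.empty_like(weight)
        bw_inp_g(weight, grad_out.contiguous(), grad_inp)   # assumed signature
        bw_wei_g(inp, grad_out.contiguous(), grad_wei)      # assumed signature
        # one gradient per forward() argument; non-tensor args get None
        return grad_inp, grad_wei, None, None, None, None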
def __init__(self):
    func_cpu, func_gpu = get_tvm_add()
    self.func = {
        'mx': {
            'cpu': to_mxnet_func(func_cpu, const_loc=[0, 1]),
            'gpu': to_mxnet_func(func_gpu, const_loc=[0, 1]),
        },
        'th': {
            'cpu': to_pytorch_func(func_cpu),
            'gpu': to_pytorch_func(func_gpu),
        },
    }
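# A minimal sketch (assumed, since get_tvm_add's body is not shown here) of what
# such a builder could return for the CPU case, and how the wrapped kernel is
# then called. Uses the te schedule API; the exact TVM version is an assumption.
import torch
import tvm
from tvm import te
from tvm.contrib.dlpack import to_pytorch_func

def get_tvm_add_cpu(n=1024):
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute((n,), lambda i: A[i] + B[i], name="C")
    s = te.create_schedule(C.op)
    return tvm.build(s, [A, B, C], target="llvm", name="add")

add_torch = to_pytorch_func(get_tvm_add_cpu())
a, b, c = torch.rand(1024), torch.rand(1024), torch.empty(1024)
add_torch(a, b, c)            # the TVM kernel fills c in place, zero-copy via DLPack
assert torch.allclose(c, a + b)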
def _get_function(dtype: str, device: str):
    '''Loads the function from disk or compiles it'''
    # A tuple of arguments that defines the function
    args = (dtype, device)
    if args not in DiagonaledMM.function_dict:
        diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device)  # try to load from disk
        if not diagonaled_mm:
            print('Tvm binary not found. Compiling ...')
            diagonaled_mm = DiagonaledMM._compile_function(dtype, device)  # compile
            DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device)  # save to disk
        # convert the tvm function into a pytorch function
        from tvm.contrib import dlpack
        diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm)  # wrap it as a pytorch function
        # save the function into a dictionary to be reused
        DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch  # save it in a dictionary for next time
    return DiagonaledMM.function_dict[args]
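# A hedged sketch of the _load_compiled_function / _save_compiled_function helpers
# referenced above (the originals are not shown here): cache the built module on
# disk so recompilation is skipped. The filenames, cache directory, and the use of
# export_library / tvm.runtime.load_module are assumptions.
import os
import tvm

def _save_compiled_function(mod, dtype, device, cache_dir="tvm_cache"):
    os.makedirs(cache_dir, exist_ok=True)
    mod.export_library(os.path.join(cache_dir, f"diagonaled_mm_{dtype}_{device}.so"))

def _load_compiled_function(dtype, device, cache_dir="tvm_cache"):
    path = os.path.join(cache_dir, f"diagonaled_mm_{dtype}_{device}.so")
    return tvm.runtime.load_module(path) if os.path.exists(path) else None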
def test():
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)

    try:
        import torch
        import torch.utils.dlpack

        x = torch.rand(56, 56)
        tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
        y = tvm.nd.from_dlpack(tvm_x)
        np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
        np.testing.assert_equal(
            torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy()
        )

        n = tvm.runtime.convert(137)
        xx = torch.rand(137, 137)
        yy = torch.rand(137, 137)
        zz2 = torch.empty(137, 137)
        zz = xx.mm(yy)
        XX = te.placeholder((n, n), name="X")
        YY = te.placeholder((n, n), name="Y")
        k = te.reduce_axis((0, n), name="k")
        ZZ = te.compute((n, n), lambda i, j: te.sum(XX[i, k] * YY[k, j], axis=k))
        s = te.create_schedule(ZZ.op)
        # No need to specify target_host if it's llvm
        # Otherwise you will need to specify the target and target_host
        f = tvm.build(s, [XX, YY, ZZ], name="f")
        f_pytorch = to_pytorch_func(f)
        zz2 = torch.empty(137, 137)
        f_pytorch(xx, yy, zz2)
        tvm.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-4, atol=1e-4)
    except ImportError:
        pass
def test():
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)

    try:
        import torch
        import torch.utils.dlpack

        x = torch.rand(56, 56)
        tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
        y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
        np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
        np.testing.assert_equal(
            torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy()
        )

        n = tvm.convert(137)
        xx = torch.rand(137, 137)
        yy = torch.rand(137, 137)
        zz2 = torch.empty(137, 137)
        zz = xx.mm(yy)
        XX = tvm.placeholder((n, n), name='X')
        YY = tvm.placeholder((n, n), name='Y')
        k = tvm.reduce_axis((0, n), name='k')
        ZZ = tvm.compute((n, n), lambda i, j: tvm.sum(XX[i, k] * YY[k, j], axis=k))
        s = tvm.create_schedule(ZZ.op)
        f = tvm.build(s, [XX, YY, ZZ], target_host='llvm', name='f')
        f_pytorch = to_pytorch_func(f)
        zz2 = torch.empty(137, 137)
        f_pytorch(xx, yy, zz2)
        np.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-6)
    except ImportError:
        pass
def fb_max(size):
    #with autotvm.apply_history_best(f'best_hmm_k{size}.log'):
    with tvm.target.create("cuda"):
        s_mult, arg_bufs = hmm_runner_max('float32', size)
        from tvm.contrib.dlpack import to_pytorch_func
        mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
        hmm_pytorch_max = to_pytorch_func(mod)

    def fb(x):
        time, batch, size, _ = x.shape
        forward = torch.zeros(time + 1, batch, size).cuda()
        y = log_eye_cat(x).cuda()
        hmm_pytorch_max(y, forward)
        del y
        #torch.cuda.empty_cache()
        backward = torch.zeros(time + 1, batch, size).cuda()
        y = log_eye_cat(x.flip(0).transpose(-2, -1)).contiguous().cuda()
        hmm_pytorch_max(y, backward)
        del y
        #torch.cuda.empty_cache()
        #check = (forward.view(time+1, batch, size)+
        #         backward.flip(0).view(time+1, batch, size))
        y = x.view(time, batch, size, size).transpose(-2, -1).contiguous().cuda()
        y += forward[:-1].view(time, batch, 1, size)
        y += backward[:-1].flip(0).view(time, batch, size, 1)
        marginals = y.transpose(-2, -1)
        #marginals = (forward[:-1].view(time, batch, 1, size) +
        #             backward[:-1].flip(0).view(time, batch, size, 1) +
        #             x.view(time, batch, size, size).transpose(-2, -1)).transpose(-2, -1)
        #return forward, backward, marginals, check
        return marginals

    return fb
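# Hedged sketches of the log_eye / log_eye_cat helpers used above, inferred from
# how they are called (log of the identity matrix: 0 on the diagonal, -inf
# elsewhere; log_eye_cat appends one such slice along the time axis). These are
# assumptions, not the original definitions.
import torch

def log_eye(k, dtype=torch.float32, device="cpu"):
    m = torch.full((k, k), float("-inf"), dtype=dtype, device=device)
    m.fill_diagonal_(0.0)
    return m

def log_eye_cat(x):
    # x: time x batch x size x size -> (time + 1) x batch x size x size
    time, batch, size, _ = x.shape
    eye = log_eye(size, dtype=x.dtype, device=x.device)
    return torch.cat([x, eye.view(1, 1, size, size).expand(1, batch, size, size)], dim=0)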
def test():
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)

    try:
        import torch
        import torch.utils.dlpack

        x = torch.rand(56, 56)
        tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
        y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
        np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
        np.testing.assert_equal(
            torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy()
        )

        n = tvm.convert(137)
        xx = torch.rand(137, 137)
        yy = torch.rand(137, 137)
        zz2 = torch.empty(137, 137)
        zz = xx.mm(yy)
        XX = tvm.placeholder((n, n), name='X')
        YY = tvm.placeholder((n, n), name='Y')
        k = tvm.reduce_axis((0, n), name='k')
        ZZ = tvm.compute((n, n), lambda i, j: tvm.sum(XX[i, k] * YY[k, j], axis=k))
        s = tvm.create_schedule(ZZ.op)
        f = tvm.build(s, [XX, YY, ZZ], target_host='llvm', name='f')
        f_pytorch = to_pytorch_func(f)
        zz2 = torch.empty(137, 137)
        f_pytorch(xx, yy, zz2)
        tvm.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-6)
    except ImportError:
        pass
    cache_read(M, AA, AL, BB, BL, ko, ki)
    cache_read(M2, AA2, AL2, BB2, BL2, ko2, ki2)
    return s, [A, B, C]

task = autotvm.task.create(logsummul, args=('float32', ), target='cuda',
                           target_host="llvm")

with autotvm.apply_history_best('best.log'):
    with tvm.target.create("cuda"):
        s_mult, arg_bufs = logsummul('float32')
        mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
        from tvm.contrib.dlpack import to_pytorch_func
        logsum_pytorch = to_pytorch_func(mod)

if __name__ == "__main__":
    from tvm import autotvm

    task = autotvm.task.create(logsummul, args=('float32', ), target='cuda',
                               target_host="llvm")
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(n_parallel=5),
        runner=autotvm.LocalRunner(number=10, repeat=3, timeout=10,
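# A generic sketch of how a measure_option like the (truncated) one above is
# typically consumed by an autotvm tuner; this is the standard pattern, not the
# original script's continuation.
from tvm import autotvm

tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(
    n_trial=1000,
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file("best.log")],
)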
def get_fb(size):
    """
    with autotvm.apply_history_best(f'best_hmm_k{size}.log'):
        with tvm.target.create("cuda"):
            s_mult, arg_bufs = hmm_runner('float32', size)
            mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
            hmm_pytorch = to_pytorch_func(mod)
    """
    with tvm.target.create("cuda"):
        s_mult, arg_bufs = hmm_runner('float32', size)
        mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
        hmm_pytorch = to_pytorch_func(mod)

    # if the padding doesn't make a difference this must be an inclusive scan
    # x: batch x time x zt x zt-1
    #@profile
    def fb(x, mask=None):
        batch, time, size, _ = x.shape
        lex = log_eye(size, dtype=x.dtype, device=x.device)
        # need time x batch x zt-1, zt
        x = x.permute(1, 0, 3, 2)
        if mask is not None:
            mask = mask.t()
            x[~mask[1:]] = lex
            """
            x.masked_scatter_(
                ~mask[1:,:,None,None],  # EXPAND HERE IS BAD?
                lex[None,None].expand(x.shape),
            )
            import pdb; pdb.set_trace()
            """
        """
        out_fb = torch.empty(time+1, batch * 2, size, device=x.device)
        out_fb.fill_(float("-inf"))
        hmm_pytorch(
            log_eye_cat(torch.cat([x, x.flip(0).transpose(-2, -1)], 1)),
            out_fb,
        )
        out_fb2 = torch.empty(time+1, batch * 2, size, device=x.device)
        out_fb2.fill_(float("-inf"))
        hmm_pytorch(
            log_eye_cat(x),
            out_fb2[:,:batch],
        )
        hmm_pytorch(
            log_eye_cat(x.flip(0).transpose(-2, -1)),
            out_fb2[:,batch:],
        )
        alphas = out_fb[:, :batch]
        betas = out_fb[:, batch:].flip(0)
        """
        out_fb = torch.empty(2, time + 1, batch, size, device=x.device)
        out_fb.fill_(float("-inf"))
        inp = torch.empty(time + 1, batch, size, size, device=x.device)
        inp[-1] = lex
        # forward
        inp[:-1] = x
        hmm_pytorch(inp, out_fb[0])
        # backward
        inp[range(time - 1, -1, -1)] = x.transpose(-2, -1)
        hmm_pytorch(inp, out_fb[1])

        alphas = out_fb[0]
        betas = out_fb[1].flip(0)

        # pay the memory cost here
        # not sure if i can flip the argument to hmm_pytorch
        log_marginals = x
        log_marginals += alphas[:-1].view(time, batch, size, 1)
        log_marginals += betas[1:].view(time, batch, 1, size)
        log_marginals -= alphas[-1].logsumexp(-1).view(1, -1, 1, 1)
        if mask is not None:
            log_marginals.masked_fill_(~mask[1:, :, None, None], float("-inf"))
        log_marginals = log_marginals.permute(1, 0, 3, 2)
        return log_marginals, alphas
        #marginals = log_marginals.exp()
        # switch back marginals: batch x time x zt x zt-1
        #return marginals, alphas, betas, log_marginals

    return fb
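# A hedged usage sketch (shapes inferred from the comments above; requires CUDA
# and the hmm_runner schedule being available): build the forward-backward
# routine for a given state size and run it on a batch of log-potentials.
fb = get_fb(64)
potentials = torch.randn(8, 20, 64, 64, device="cuda")   # batch x time x zt x zt-1
log_marginals, alphas = fb(potentials)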