Code Example #1
File: load_func_map.py  Project: WeihaoZhuang/fast3d
from tvm.contrib.dlpack import to_pytorch_func  # best_functions is a fast3d project helper

def best_pytorch_func(data_dict):
    # Collect the best TVM kernels per entry, wrapped as PyTorch functions.
    func_dict = {"fw": [], "bw_inp_g": [], "bw_wei_g": []}

    for val in data_dict.values():
        # Fastest forward / input-gradient / weight-gradient kernels.
        func_fw, func_inp_g, func_wei_g = best_functions(val.values())

        func_dict['fw'].append(to_pytorch_func(func_fw))
        func_dict['bw_inp_g'].append(to_pytorch_func(func_inp_g))
        func_dict['bw_wei_g'].append(to_pytorch_func(func_wei_g))
    return func_dict
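
The wrapped functions follow TVM's destination-passing convention: you preallocate the output tensor and the kernel writes into it in place. A minimal, self-contained sketch of the pattern (illustrative kernel, not from the fast3d project):

import tvm
from tvm import te
from tvm.contrib.dlpack import to_pytorch_func
import torch

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], name="double")

f_torch = to_pytorch_func(f)
x = torch.rand(16)
y = torch.empty(16)
f_torch(x, y)  # result is written into y in place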
Code Example #2
    def __init__(self):
        # Build the TVM add kernel for both CPU and GPU.
        func_cpu, func_gpu = get_tvm_add()

        # Dispatch table: framework ('mx' / 'th') -> device -> wrapped kernel.
        self.func = {
            'mx': {
                'cpu': to_mxnet_func(func_cpu, const_loc=[0, 1]),
                'gpu': to_mxnet_func(func_gpu, const_loc=[0, 1]),
            },
            'th': {
                'cpu': to_pytorch_func(func_cpu),
                'gpu': to_pytorch_func(func_gpu),
            }
        }
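
Note that const_loc=[0, 1] tells the MXNet bridge which positional arguments are read-only inputs; the PyTorch wrapper needs no such hint. A hypothetical call through the table (the owning class is not shown in the snippet):

out = torch.empty_like(a)
bridge.func['th']['cpu'](a, b, out)  # out = a + b, written in place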
Code Example #3
    def _get_function(dtype: str, device: str):
        '''Load the function from disk, or compile it if no binary is found.'''
        # A tuple of arguments that identifies the function
        args = (dtype, device)
        if args not in DiagonaledMM.function_dict:
            diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device)  # try to load from disk
            if not diagonaled_mm:
                print('Tvm binary not found. Compiling ...')
                diagonaled_mm = DiagonaledMM._compile_function(dtype, device)  # compile
                DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device)  # save to disk
            # convert the TVM function into a PyTorch function
            from tvm.contrib import dlpack
            diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm)  # wrap it as a pytorch function
            # cache the wrapped function so it is reused on later calls
            DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch
        return DiagonaledMM.function_dict[args]
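
A hypothetical call (argument values are illustrative, not from the snippet):

mm = DiagonaledMM._get_function(dtype='float32', device='cuda')  # compiled once, cached afterwards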
Code Example #4
File: test_dlpack.py  Project: vinceab/tvm
def test():
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)

    try:
        import torch
        import torch.utils.dlpack

        x = torch.rand(56, 56)
        tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
        y = tvm.nd.from_dlpack(tvm_x)
        np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
        np.testing.assert_equal(
            torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(),
            tvm_x.asnumpy())

        n = tvm.runtime.convert(137)
        xx = torch.rand(137, 137)
        yy = torch.rand(137, 137)
        zz = xx.mm(yy)
        XX = te.placeholder((n, n), name="X")
        YY = te.placeholder((n, n), name="Y")

        k = te.reduce_axis((0, n), name="k")
        ZZ = te.compute((n, n),
                        lambda i, j: te.sum(XX[i, k] * YY[k, j], axis=k))
        s = te.create_schedule(ZZ.op)
        # No need to specify target_host if it's llvm;
        # otherwise you will need to specify the target and target_host.
        f = tvm.build(s, [XX, YY, ZZ], name="f")

        f_pytorch = to_pytorch_func(f)
        zz2 = torch.empty(137, 137)
        f_pytorch(xx, yy, zz2)
        tvm.testing.assert_allclose(zz.numpy(),
                                    zz2.numpy(),
                                    rtol=1e-4,
                                    atol=1e-4)

    except ImportError:
        pass
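
The DLPack conversions above are zero-copy: both handles alias one buffer, which a quick check illustrates (assuming the same imports as the test):

x = torch.zeros(4)
tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
x[0] = 7.0
print(tvm_x.asnumpy()[0])  # 7.0 -- the two views share one buffer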
Code Example #5
File: test_dlpack.py  Project: zhangquan920/tasn
def test():
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)

    try:
        import torch
        import torch.utils.dlpack

        x = torch.rand(56, 56)
        tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
        y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
        np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
        np.testing.assert_equal(
            torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(),
            tvm_x.asnumpy())

        n = tvm.convert(137)
        xx = torch.rand(137, 137)
        yy = torch.rand(137, 137)
        zz = xx.mm(yy)
        XX = tvm.placeholder((n, n), name='X')
        YY = tvm.placeholder((n, n), name='Y')

        k = tvm.reduce_axis((0, n), name='k')
        ZZ = tvm.compute((n, n),
                         lambda i, j: tvm.sum(XX[i, k] * YY[k, j], axis=k))
        s = tvm.create_schedule(ZZ.op)
        f = tvm.build(s, [XX, YY, ZZ], target_host='llvm', name='f')

        f_pytorch = to_pytorch_func(f)
        zz2 = torch.empty(137, 137)
        f_pytorch(xx, yy, zz2)
        np.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-6)

    except ImportError:
        pass
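
This example mirrors #4 but targets the older TVM API, where placeholder, compute, reduce_axis, and create_schedule lived at the top level of the tvm module rather than in tvm.te, and convert had not yet moved to tvm.runtime.convert.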
Code Example #6
File: hmm.py  Project: muralits98/cascaded-generation
def fb_max(size):
    #with autotvm.apply_history_best(f'best_hmm_k{size}.log'):
    with tvm.target.create("cuda"):
        s_mult, arg_bufs = hmm_runner_max('float32', size)
        from tvm.contrib.dlpack import to_pytorch_func
        mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
        hmm_pytorch_max = to_pytorch_func(mod)

    def fb(x):
        # x: log-potentials of shape (time, batch, size, size).
        time, batch, size, _ = x.shape

        # Forward max pass.
        forward = torch.zeros(time + 1, batch, size).cuda()
        y = log_eye_cat(x).cuda()
        hmm_pytorch_max(y, forward)
        del y
        #torch.cuda.empty_cache()

        # Backward max pass over the time-reversed, transposed potentials.
        backward = torch.zeros(time + 1, batch, size).cuda()
        y = log_eye_cat(x.flip(0).transpose(-2, -1)).contiguous().cuda()
        hmm_pytorch_max(y, backward)
        del y
        #torch.cuda.empty_cache()

        #check = (forward.view(time+1, batch, size)+
        #    backward.flip(0).view(time+1, batch, size))
        # Combine the two passes into marginals.
        y = x.view(time, batch, size, size).transpose(-2, -1).contiguous().cuda()
        y += forward[:-1].view(time, batch, 1, size)
        y += backward[:-1].flip(0).view(time, batch, size, 1)
        marginals = y.transpose(-2, -1)

        #marginals = (forward[:-1].view(time, batch, 1, size) +
        #             backward[:-1].flip(0).view(time, batch, size, 1) +
        #             x.view(time, batch, size, size).transpose(-2, -1)).transpose(-2, -1)
        #return forward, backward, marginals, check
        return marginals

    return fb
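
A hypothetical call, with shapes inferred from the unpacking inside fb (the max passes suggest these are max-marginals, Viterbi-style):

fb = fb_max(size=64)
x = torch.randn(10, 2, 64, 64).cuda()  # (time, batch, size, size) log-potentials
marginals = fb(x)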
Code Example #7
File: test_dlpack.py  Project: LANHUIYING/tvm
def test():
    a = np.random.randn(1337)
    tvm_a = tvm.nd.array(a)
    np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)

    try:
        import torch
        import torch.utils.dlpack

        x = torch.rand(56, 56)
        tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
        y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
        np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
        np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy())

        n = tvm.convert(137)
        xx = torch.rand(137, 137)
        yy = torch.rand(137, 137)
        zz = xx.mm(yy)
        XX = tvm.placeholder((n, n), name='X')
        YY = tvm.placeholder((n, n), name='Y')

        k = tvm.reduce_axis((0, n), name='k')
        ZZ = tvm.compute((n, n), lambda i, j: tvm.sum(XX[i, k] * YY[k, j], axis=k))
        s = tvm.create_schedule(ZZ.op)
        f = tvm.build(s, [XX, YY, ZZ], target_host='llvm', name='f')

        f_pytorch = to_pytorch_func(f)
        zz2 = torch.empty(137,137)
        f_pytorch(xx, yy, zz2)
        tvm.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-6)

    except ImportError:
        pass
Code Example #8
    cache_read(M, AA, AL, BB, BL, ko, ki)
    cache_read(M2, AA2, AL2, BB2, BL2, ko2, ki2)

    return s, [A, B, C]


task = autotvm.task.create(logsummul,
                           args=('float32', ),
                           target='cuda',
                           target_host="llvm")
with autotvm.apply_history_best('best.log'):
    with tvm.target.create("cuda"):
        s_mult, arg_bufs = logsummul('float32')
        mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
from tvm.contrib.dlpack import to_pytorch_func
logsum_pytorch = to_pytorch_func(mod)

if __name__ == "__main__":
    from tvm import autotvm

    task = autotvm.task.create(logsummul,
                               args=('float32', ),
                               target='cuda',
                               target_host="llvm")
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(n_parallel=5),
        runner=autotvm.LocalRunner(number=10,
                                   repeat=3,
                                   timeout=10))  # the listing is truncated here
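
The snippet ends mid-setup. An assumed continuation (not from the original project) would run a tuner and log results for the apply_history_best('best.log') call above:

    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=100,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('best.log')])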
Code Example #9
def get_fb(size):
    """
    with autotvm.apply_history_best(f'best_hmm_k{size}.log'):
        with tvm.target.create("cuda"):
            s_mult, arg_bufs = hmm_runner('float32', size)
            mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
            hmm_pytorch = to_pytorch_func(mod)
    """
    with tvm.target.create("cuda"):
        s_mult, arg_bufs = hmm_runner('float32', size)
        mod = tvm.build(s_mult, arg_bufs, target="cuda", target_host="llvm")
        hmm_pytorch = to_pytorch_func(mod)

    # if the padding doesn't make a difference this must be an inclusive scan
    # x: batch x time x zt x zt-1
    #@profile
    def fb(x, mask=None):
        batch, time, size, _ = x.shape
        lex = log_eye(size, dtype=x.dtype, device=x.device)
        # need time x batch x zt-1, zt
        x = x.permute(1, 0, 3, 2)
        if mask is not None:
            mask = mask.t()
            x[~mask[1:]] = lex
            """
            x.masked_scatter_(
                ~mask[1:,:,None,None],
                # EXPAND HERE IS BAD?
                lex[None,None].expand(x.shape),
            )
            import pdb; pdb.set_trace()
            """
        """
        out_fb = torch.empty(time+1, batch * 2, size, device=x.device)
        out_fb.fill_(float("-inf"))
        hmm_pytorch(
            log_eye_cat(torch.cat([x, x.flip(0).transpose(-2, -1)], 1)),
            out_fb,
        )

        out_fb2 = torch.empty(time+1, batch * 2, size, device=x.device)
        out_fb2.fill_(float("-inf"))
        hmm_pytorch(
            log_eye_cat(x),
            out_fb2[:,:batch],
        )
        hmm_pytorch(
            log_eye_cat(x.flip(0).transpose(-2, -1)),
            out_fb2[:,batch:],
        )
        alphas = out_fb[:, :batch]
        betas = out_fb[:, batch:].flip(0)
        """

        out_fb = torch.empty(2, time + 1, batch, size, device=x.device)
        out_fb.fill_(float("-inf"))
        inp = torch.empty(time + 1, batch, size, size, device=x.device)
        inp[-1] = lex
        # forward
        inp[:-1] = x
        hmm_pytorch(inp, out_fb[0])
        # backward
        inp[range(time - 1, -1, -1)] = x.transpose(-2, -1)
        hmm_pytorch(inp, out_fb[1])

        alphas = out_fb[0]
        betas = out_fb[1].flip(0)  # pay the memory cost here
        # not sure if i can flip the argument to hmm_pytorch

        log_marginals = x
        log_marginals += alphas[:-1].view(time, batch, size, 1)
        log_marginals += betas[1:].view(time, batch, 1, size)
        log_marginals -= alphas[-1].logsumexp(-1).view(1, -1, 1, 1)
        if mask is not None:
            log_marginals.masked_fill_(~mask[1:, :, None, None], float("-inf"))
        log_marginals = log_marginals.permute(1, 0, 3, 2)
        return log_marginals, alphas
        #marginals = log_marginals.exp()
        # switch back marginals: batch x time x zt x zt-1
        #return marginals, alphas, betas, log_marginals

    return fb
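
A hypothetical call, with shapes taken from the comment in the code (x: batch x time x zt x zt-1):

fb = get_fb(size=64)
x = torch.randn(8, 20, 64, 64).cuda()  # (batch, time, zt, zt-1) log-potentials
log_marginals, alphas = fb(x)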