예제 #1
0
 def check_llvm():
     """Build the schedule for ``ext_dev``/``llvm`` and check that B == A + 1.

     The check is skipped entirely when the "llvm" device is not enabled.
     """
     if tvm.testing.device_enabled("llvm"):
         kernel = tvm.build(s, [A, B], "ext_dev", "llvm")
         dev = tvm.ext_dev(0)
         # Random input data and a zero-filled output buffer, both on ``dev``.
         input_values = np.random.uniform(size=n).astype(A.dtype)
         a = tvm.nd.array(input_values, dev)
         output_zeros = np.zeros(n, dtype=B.dtype)
         b = tvm.nd.array(output_zeros, dev)
         # Launch the kernel, then compare the output against input + 1.
         kernel(a, b)
         expected = a.asnumpy() + 1
         tvm.testing.assert_allclose(b.asnumpy(), expected)
예제 #2
0
파일: test_ext.py 프로젝트: LANHUIYING/tvm
 def check_llvm():
     """Compile the schedule with ``ext_dev``/``llvm`` and verify B == A + 1.

     Returns without doing anything when ``tvm.module.enabled("llvm")`` is false.
     """
     if tvm.module.enabled("llvm"):
         compiled = tvm.build(s, [A, B], "ext_dev", "llvm")
         device = tvm.ext_dev(0)
         # Allocate the random input and the zeroed output on the device.
         random_input = np.random.uniform(size=n).astype(A.dtype)
         a = tvm.nd.array(random_input, device)
         zero_output = np.zeros(n, dtype=B.dtype)
         b = tvm.nd.array(zero_output, device)
         # Run the kernel and check the result element-wise.
         compiled(a, b)
         reference = a.asnumpy() + 1
         tvm.testing.assert_allclose(b.asnumpy(), reference)
예제 #3
0
# under the License.
"""Python VTA Deploy."""
from __future__ import absolute_import, print_function

import os
from os.path import join
from io import BytesIO
from PIL import Image

import requests
import numpy as np

import tvm
from tvm.contrib import graph_runtime, download

# Module-level context obtained from tvm.ext_dev(0); created once at import time.
CTX = tvm.ext_dev(0)


def load_vta_library():
    """Load ``build/libvta.so`` from the project root as a TVM runtime module.

    The project root is taken to be four directories above the directory that
    contains this file.

    Returns
    -------
    The module object returned by ``tvm.runtime.load_module``.
    """
    this_file = os.path.abspath(os.path.expanduser(__file__))
    this_dir = os.path.dirname(this_file)
    # Walk four levels up from this file's directory to reach the project root.
    root_dir = os.path.abspath(join(this_dir, "../../../../"))
    library_path = os.path.abspath(join(root_dir, "build/libvta.so"))
    return tvm.runtime.load_module(library_path)


def load_model():
    """ Load VTA Model  """

    load_vta_library()