Ejemplo n.º 1
0
    def test_resnet_infer_with_feature(self):
        """Run a pretrained ResNet-18 on a downloaded image with tracing enabled
        and check that the dumped trace is self-consistent.

        The trace returned by ``jt.dump_trace_data()`` is pickled to
        ``<jt.flags.cache_path>/resnet_with_feature.pkl``. Then every op id
        listed under ``fused_ops`` in each ``execute_op_info`` entry must have
        a matching key in ``node_data``.

        Requires network access to download the test image and the
        pretrained weights.
        """
        cat_url = "https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=3782485413,1118109468&fm=26&gp=0.jpg"
        import jittor_utils
        cat_path = f"{jt.flags.cache_path}/cat.jpg"
        print("download")
        jittor_utils.download(cat_url, cat_path)
        with open(cat_path, 'rb') as f:
            img = Image.open(f).convert('RGB')
            # np.array of an RGB PIL image is laid out (H, W, 3).
            img = jt.array(np.array(img))
            print(img.shape, img.dtype)
            # Shift/scale pixel values, then reorder (H, W, C) -> (C, H, W).
            img = ((img.float() - 128) / 255).transpose(2, 0, 1)

        # Everything run inside this scope is traced; the recorded trace is
        # read back with jt.dump_trace_data() below.
        with jt.flag_scope(trace_py_var=2, trace_var_data=1):
            img = img[None, ...]  # prepend a batch axis -> (1, C, H, W)

            resnet18 = resnet.Resnet18(pretrained=True)
            x = jt.float32(img)
            y = resnet18(x)
            # NOTE(review): presumably forces the graph to execute so the
            # trace is populated before it is dumped — confirm.
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            with open(f"{jt.flags.cache_path}/resnet_with_feature.pkl",
                      "wb") as f:
                pickle.dump(data, f)
            # Every fused op referenced by an executed op must have node data.
            for info in data["execute_op_info"].values():
                for op_id in info['fused_ops']:
                    assert op_id in data["node_data"], (op_id, "not found")
Ejemplo n.º 2
0
#     os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Load the OpenMP runtime that matches the compiler family in `cc_type`
# (an unsupported cc_type raises KeyError here).
# RTLD_GLOBAL makes the library's symbols available to libraries loaded
# afterwards (presumably so the jittor_core extension can resolve them).
libname = {"clang": "omp", "icc": "iomp5", "g++": "gomp"}[cc_type]
libname = ctypes.util.find_library(libname)
# Explicit check instead of `assert`: under `python -O` an assert is stripped,
# and on POSIX ctypes.CDLL(None, ...) would then silently return a handle to
# the main program instead of failing.
if libname is None:
    raise RuntimeError("openmp library not found")
ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)

# When the package ships a `version` file, download the prebuilt object for
# this (version, compiler, cuda/cpu) combination into the cache and add it to
# the list of files passed to the compiler.
version_file = os.path.join(jittor_path, "version")
if os.path.isfile(version_file):
    with open(version_file, 'r', encoding='utf-8') as f:
        version = f.read().strip()
    key = f"{version}-{cc_type}-{'cuda' if has_cuda else 'cpu'}.o"
    # TODO: open the website
    extra_obj = os.path.join(cache_path, key)
    # A URL is not a filesystem path: build it with plain string
    # concatenation, not os.path.join (whose separator is platform-dependent).
    url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/build/" + key
    jit_utils.download(url, extra_obj)
    files.append(extra_obj)

# Compile everything collected in `files` into the `jittor_core` extension
# module, named with the platform's extension suffix.
compile(cc_path, cc_flags + opt_flags, files, f"jittor_core{extension_suffix}")

# TODO: move to compile_extern.py
compile_extern()

# Import the freshly built extension inside the import scope configured by
# `import_flags`.
with jit_utils.import_scope(import_flags):
    import jittor_core as core

flags = core.flags()
if has_cuda:
    # Pass every configured CUDA architecture to nvcc; for example, archs
    # [70, 75] would append " -arch=sm_70,sm_75 ".
    nvcc_flags += f" -arch={','.join(f'sm_{arch}' for arch in flags.cuda_archs)} "

flags.cc_path = cc_path