Example #1
0
def test_torch_dialect():
    """Check go-to-definition results for the dummy PyTorch repository.

    Each probe asks the language server for the definition at a
    (line, character) position inside a Python file and verifies the URI
    suffix and start line of the first location returned.
    """
    repo_root = os.path.join(curr_path, "..", "dummy_repo", "pytorch")

    server = langserver.BaseServer()
    server.m_initialize(rootUri=langserver.path2uri(repo_root))

    def first_definition(rel_path, line, character):
        # Run the query and hand back the first hit; at least one is required.
        found = run_find_definition(server, join_path(repo_root, rel_path),
                                    line, character)
        assert len(found) > 0
        return found[0]

    hit = first_definition("torch/nn/quantized/modules/conv.py", 38, 28)
    assert hit['uri'].endswith("qconv.cpp")
    assert hit['range']['start']['line'] == 2

    hit = first_definition("torch/jit/__init__.py", 20, 50)
    assert hit['uri'].endswith("init.cpp")
    assert hit['range']['start']['line'] == 95

    hit = first_definition("torch/jit/__init__.py", 25, 30)
    assert hit['uri'].endswith("init.cpp")
    assert hit['range']['start']['line'] == 1
    # This probe additionally pins the end column of the reported range.
    assert hit['range']['end']['character'] == 27

    hit = first_definition("torch/nn/functional.py", 16, 30)
    assert hit['uri'].endswith("python_torch_functions.cpp")
    assert hit['range']['start']['line'] == 2
def test_taichi_dialect():
    """Check go-to-definition results for the dummy Taichi repository.

    Every case queries one position in a Python file and verifies the URI
    suffix and start line of the first location returned.
    """
    repo_root = os.path.join(curr_path, "..", "dummy_repo", "taichi")
    server = langserver.BaseServer()
    server.m_initialize(rootUri=langserver.path2uri(repo_root))

    # (relative path, line, character, expected URI suffix, expected start line)
    cases = [
        # ti.core.global_var_expr_from_snode
        ("python/taichi/lang/snode.py", 4, 40, "python_bindings.cpp", 8),
        # taichi_lang_core.create_kernel
        ("python/taichi/lang/kernel.py", 74, 40, "python_bindings.cpp", 11),
        # tc_core.Array2DVector4
        ("python/taichi/misc/util.py", 34, 40, "export_math.cpp", 12),
        # core.get_current_program()
        ("python/taichi/lang/__init__.py", 10, 30, "python_bindings.cpp", 15),
    ]

    for rel_path, line, character, uri_suffix, start_line in cases:
        found = run_find_definition(server, join_path(repo_root, rel_path),
                                    line, character)
        assert len(found) > 0
        first = found[0]
        assert first['uri'].endswith(uri_suffix)
        assert first['range']['start']['line'] == start_line
    def test_real_repo():
        """Run definition/reference queries against a real TVM checkout.

        The checkout is looked up three directories above ``curr_path``. When
        that path does not exist, a message is logged and the function returns
        without running any query.
        """
        # tested on tvm git tag e69bd1284b50630df570b3a5779a801982203756
        tvm_path = os.path.join(curr_path, "..", "..", "..", "tvm")
        if not os.path.exists(tvm_path):
            logging.info("Skip tvm tests")
            return

        server = langserver.BaseServer()
        server.m_initialize(rootUri=langserver.path2uri(tvm_path))

        # Queries whose results are not inspected; they only have to complete.
        # Order is kept exactly as the queries were originally issued.
        unchecked_queries = [
            (run_find_references, "src/runtime/module.cc", 198, 34),
            (run_find_references, "python/tvm/api.py", 58, 33),
            (run_find_definition, "python/tvm/relay/expr.py", 177, 14),
            (run_find_references, "src/relay/ir/expr.cc", 39, 33),
            (run_find_definition, "python/tvm/stmt.py", 96, 34),
            (run_find_references, "python/tvm/stmt.py", 96, 34),
            (run_find_definition, "python/tvm/stmt.py", 56, 18),
            (run_find_references, "python/tvm/stmt.py", 56, 18),
            (run_find_definition, "src/relay/backend/compile_engine.cc", 730,
             59),
            (run_find_references, "src/relay/backend/compile_engine.cc", 730,
             59),
        ]
        for query, rel_path, line, character in unchecked_queries:
            query(server, join_path(tvm_path, rel_path), line, character)

        # Reference queries that must each report exactly six locations.
        counted_queries = [
            # TVM_REGISTER_API("ir_pass.Simplify")
            ("src/api/api_pass.cc", 33, 30),
            # _pass.Simplify(end - begin)
            ("python/tvm/ir_builder.py", 214, 48),
            # REGISTER_MAKE(Provide);
            ("src/api/api_ir.cc", 156, 15),
        ]
        for rel_path, line, character in counted_queries:
            found = run_find_references(server, join_path(tvm_path, rel_path),
                                        line, character)
            assert len(found) == 6
Example #4
0
def run_find_definition(server, path, line, character):
    """Ask ``server`` for the definition at a position in a file.

    Args:
        server: Object exposing ``m_text_document__definition``.
        path: File path; converted to a URI with ``langserver.path2uri``.
        line: Line number forwarded unchanged in the ``position`` argument.
        character: Column forwarded unchanged in the ``position`` argument.

    Returns:
        Whatever ``server.m_text_document__definition`` returns.
    """
    document = {"uri": langserver.path2uri(path)}
    cursor = {"line": line, "character": character}
    return server.m_text_document__definition(textDocument=document,
                                              position=cursor)
Example #5
0
def test_dgl_dialect():
    """Issue one go-to-definition query against the dummy DGL repository.

    NOTE(review): the query result is not asserted on in this function, so
    the test only verifies that the call completes - confirm whether
    assertions are missing.
    """
    repo_root = os.path.join(curr_path, "..", "dummy_repo", "dgl")
    server = langserver.BaseServer()
    server.m_initialize(rootUri=langserver.path2uri(repo_root))

    target = join_path(repo_root, "python/dgl/nodeflow.py")
    found = run_find_definition(server, target, 16, 20)
Example #6
0
def test_mxnet_dialect():
    """Check one go-to-definition result for the dummy MXNet repository.

    The first location returned for the probed position must have a URI
    ending in ``c_api_executor.cc`` and start on line 25.
    """
    repo_root = os.path.join(curr_path, "..", "dummy_repo", "mxnet")
    server = langserver.BaseServer()
    server.m_initialize(rootUri=langserver.path2uri(repo_root))

    target = join_path(repo_root, "python/mxnet/executor.py")
    found = run_find_definition(server, target, 55, 35)
    assert len(found) > 0

    first = found[0]
    assert first['uri'].endswith("c_api_executor.cc")
    assert first['range']['start']['line'] == 25
Example #7
0
def run_check_langserver(tvm_path):
    """Run a fixed series of definition/reference queries on a TVM tree.

    A fresh ``langserver.BaseServer`` is initialized with ``tvm_path`` as its
    root and then queried at hard-coded positions. The query results are
    discarded, so this only exercises the server.

    Args:
        tvm_path: Root directory of the TVM source tree to query.

    Returns:
        None.
    """
    server = langserver.BaseServer()
    server.m_initialize(rootUri=langserver.path2uri(tvm_path))

    find_definition = server.m_text_document__definition
    find_references = server.m_text_document__references

    # (request, path relative to tvm_path, line, character), issued in order.
    queries = [
        (find_references, "python/tvm/api.py", 58, 33),
        (find_definition, "python/tvm/relay/expr.py", 177, 14),
        (find_references, "src/relay/ir/expr.cc", 39, 33),
        (find_definition, "python/tvm/stmt.py", 96, 34),
        (find_references, "python/tvm/stmt.py", 96, 34),
        (find_references, "python/tvm/stmt.py", 56, 18),
        (find_definition, "python/tvm/stmt.py", 56, 18),
        (find_definition, "src/relay/backend/compile_engine.cc", 727, 59),
        (find_references, "src/relay/backend/compile_engine.cc", 727, 59),
    ]

    for request, rel_path, line, character in queries:
        uri = langserver.path2uri(os.path.join(tvm_path, rel_path))
        request(textDocument={"uri": uri},
                position={"line": line, "character": character})
Example #8
0
    def test_dummy_repo():
        """Verify definition and reference lookups against the dummy TVM repo.

        Definition probes must return exactly one location; reference probes
        must return the stated number of locations. Selected locations are
        checked for their URI suffix and start line.
        """
        # test and verify against dummy repo
        tvm_path = os.path.join(curr_path, "..", "dummy_repo", "tvm")
        server = langserver.BaseServer()
        server.m_initialize(rootUri=langserver.path2uri(tvm_path))

        def definitions(rel_path, line, character):
            # Go-to-definition at a position inside the dummy repo.
            return run_find_definition(server, join_path(tvm_path, rel_path),
                                       line, character)

        def references(rel_path, line, character):
            # Find-references at a position inside the dummy repo.
            return run_find_references(server, join_path(tvm_path, rel_path),
                                       line, character)

        def check(location, uri_suffix, start_line):
            # Verify the target file and the start line of one location.
            assert location['uri'].endswith(uri_suffix)
            assert location['range']['start']['line'] == start_line

        # Constant
        found = definitions("python/tvm/relay/expr.py", 15, 14)
        assert len(found) == 1
        check(found[0], "expr.h", 33)

        # _make.ProducerConsumer
        found = definitions("python/tvm/stmt.py", 26, 30)
        assert len(found) == 1
        check(found[0], "api_ir.cc", 14)

        # _make.LetStmt
        found = definitions("python/tvm/stmt.py", 46, 20)
        assert len(found) == 1
        check(found[0], "api_ir.cc", 15)

        # Get("relay.backend.lower") from c++ to python
        found = definitions("src/relay/backend/compile_engine.cc", 74, 59)
        assert len(found) == 1
        check(found[0], "_backend.py", 8)

        # Variable
        found = references("include/tvm/expr.h", 15, 49)
        assert len(found) == 2
        check(found[1], "expr.py", 15)

        # TVM_REGISTER_GLOBAL("_min_value")
        found = references("src/api/api_lang.cc", 15, 33)
        assert len(found) == 2
        check(found[1], "api.py", 24)

        # _make.Constant
        found = references("src/relay/ir/expr.cc", 16, 33)
        assert len(found) == 2
        check(found[1], "expr.py", 24)

        # REGISTER_MAKE(ProducerConsumer)
        found = references("src/api/api_ir.cc", 14, 25)
        assert len(found) == 2
        check(found[1], "stmt.py", 26)

        # REGISTER_MAKE(LetStmt)
        found = references("src/api/api_ir.cc", 15, 18)
        assert len(found) == 2
        check(found[1], "stmt.py", 46)

        # @register_func("relay.backend.build")
        found = references("python/tvm/relay/backend/_backend.py", 26, 30)
        assert len(found) == 3
        check(found[1], "compile_engine.cc", 90)
        check(found[2], "interpreter.cc", 115)

        # _pass.Simplify(end - begin)
        found = references("python/tvm/ir_builder.py", 20, 48)
        assert len(found) == 6
        check(found[0], "api_pass.cc", 10)
        check(found[1], normalize_path("autotvm/util.py"), 26)
        check(found[2], normalize_path("autotvm/util.py"), 50)
        check(found[3], "build_module.py", 98)
        check(found[4], normalize_path("hybrid/parser.py"), 43)

        # REGISTER_MAKE(Provide);
        found = references("src/api/api_ir.cc", 16, 15)
        assert len(found) == 6
        check(found[1], normalize_path("hybrid/parser.py"), 75)
        check(found[2], normalize_path("hybrid/parser.py"), 81)
        check(found[3], normalize_path("hybrid/parser.py"), 97)
        check(found[4], normalize_path("hybrid/util.py"), 20)
        check(found[5], "stmt.py", 68)
def test_torch_dialect():
    """Check go-to-definition results for the dummy PyTorch repository.

    Every case queries one position in a Python file and verifies the URI
    suffix and start line of the first location returned; one case also
    checks the end column of the reported range.
    """
    repo_root = os.path.join(curr_path, "..", "dummy_repo", "pytorch")

    server = langserver.BaseServer()
    server.m_initialize(rootUri=langserver.path2uri(repo_root))

    # (relative path, line, character,
    #  expected URI suffix, expected start line, expected end column or None)
    cases = [
        # ops.quantized.conv2d
        ("torch/nn/quantized/modules/conv.py", 38, 28, "qconv.cpp", 2, None),
        # torch._C._jit_script_class_compile
        ("torch/jit/__init__.py", 20, 50, "init.cpp", 126, None),
        # torch._C.CompilationUnit()
        ("torch/jit/__init__.py", 25, 30, "init.cpp", 1, 27),
        # torch.conv2d
        ("torch/nn/functional.py", 16, 30, "python_torch_functions.cpp", 2,
         None),
        # module._c._create_method_from_trace
        ("torch/jit/__init__.py", 61, 30, "init.cpp", 105, None),
        # self._c._get_method(attr)
        ("torch/jit/__init__.py", 106, 30, "init.cpp", 21, None),
        # self._c._define(self._concrete_type, src, rcb)
        ("torch/jit/__init__.py", 98, 18, "init.cpp", 94, None),
        # Variable._execution_engine.run_backward
        ("torch/autograd/__init__.py", 24, 40, "python_engine.cpp", 13, None),
        # _C._FunctionBase._do_forward
        ("torch/autograd/function.py", 5, 40, "python_function.cpp", 11,
         None),
        # torch._C._get_qengine()
        ("torch/backends/quantized/__init__.py", 6, 45, "Module.cpp", 46,
         None),
    ]

    for rel_path, line, character, uri_suffix, start_line, end_column in cases:
        found = run_find_definition(server, join_path(repo_root, rel_path),
                                    line, character)
        assert len(found) > 0
        first = found[0]
        assert first['uri'].endswith(uri_suffix)
        assert first['range']['start']['line'] == start_line
        if end_column is not None:
            assert first['range']['end']['character'] == end_column