Пример #1
0
def test_incompatible_typecall_var_unification():
    """Unify two TypeCalls whose callee GlobalTypeVars are distinct.

    Each call is applied to a fresh IncompleteType, so only the callees
    ("gtv1" vs "gtv2") differ. No assertion is made here.

    NOTE(review): presumably the suite expects ``solver.Unify`` to raise
    (e.g. via an expected-exception decorator not visible here) — confirm.
    """
    solver = make_solver()
    lhs_head = relay.GlobalTypeVar("gtv1")
    rhs_head = relay.GlobalTypeVar("gtv2")

    lhs_hole, rhs_hole = relay.IncompleteType(), relay.IncompleteType()

    solver.Unify(
        relay.TypeCall(lhs_head, [lhs_hole]),
        relay.TypeCall(rhs_head, [rhs_hole]),
    )
Пример #2
0
def test_incompatible_typecall_args_unification():
    """Unify TypeCalls that share a callee but have irreconcilable arguments.

    The pattern side repeats one IncompleteType in both tuple slots. That
    single hole would have to match both the (1, 2, 3) tensor and the
    (2, 3) tensor on the concrete side, which differ. No assertion is made.

    NOTE(review): presumably the suite expects ``solver.Unify`` to raise
    (e.g. via an expected-exception decorator not visible here) — confirm.
    """
    solver = make_solver()
    head = relay.GlobalTypeVar("gtv1")
    shared_hole = relay.IncompleteType()
    tail_hole = relay.IncompleteType()

    rank3 = relay.TensorType((1, 2, 3), "float32")
    rank2 = relay.TensorType((2, 3), "float32")
    rank1 = relay.TensorType((3, ), "float32")

    pattern_args = [relay.TupleType([shared_hole, shared_hole]), tail_hole]
    concrete_args = [relay.TupleType([rank3, rank2]), rank1]

    pattern = relay.TypeCall(head, pattern_args)
    concrete = relay.TypeCall(head, concrete_args)
    solver.Unify(pattern, concrete)
Пример #3
0
def test_typecall_kind():
    """A TypeCall to an ADT registered in the module has kind ``Type``.

    This is checked twice: once for an ADT with no type parameters applied
    to no arguments, and once for an ADT with one type parameter applied to
    the empty tuple type.
    """
    head = relay.GlobalTypeVar("gtv")

    # ADT with no type variables and no constructors.
    nullary_mod = tvm.IRModule()
    nullary_mod[head] = relay.TypeData(head, [], [])
    nullary_call = relay.TypeCall(head, [])
    assert check_kind(nullary_call, nullary_mod) == relay.TypeKind.Type

    # ADT with a single type variable, instantiated with ().
    unary_mod = tvm.IRModule()
    param = relay.TypeVar("tv")
    unary_mod[head] = relay.TypeData(head, [param], [])
    unary_call = relay.TypeCall(head, [relay.TupleType([])])
    assert check_kind(unary_call, unary_mod) == relay.TypeKind.Type
Пример #4
0
def test_alloc_tensor():
    """Check device placement that context analysis assigns around alloc_tensor.

    Builds ``main(x: Storage) = alloc_tensor(x, offset=0, shape=[3, 2])``.
    It then runs ``context_analysis`` with the CUDA device as the default and
    checks which device type each sub-expression is mapped to. The test
    returns early, without checking anything, when CUDA is not enabled or
    device 0 does not exist.
    """
    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
        return

    mod = tvm.IRModule()
    mod.import_from_std("core.rly")
    sto_type = relay.TypeCall(mod.get_global_type_var("Storage"), [])
    sto = relay.Var("x", sto_type)
    sh = relay.const(np.array([3, 2]), dtype="int64")
    at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh)
    mod["main"] = relay.Function([sto], at)
    ca = context_analysis(mod, tvm.cuda())
    main = mod["main"]
    body = main.body

    cpu_dev = tvm.cpu().device_type
    gpu_dev = tvm.cuda().device_type
    # Input of the function falls back to the default device gpu
    assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev

    assert isinstance(body, relay.Call) and len(body.args) == 3
    # storage of alloc_tensor (args[0]) falls back to the default device gpu
    assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev
    # args[1] is on cpu. Per the alloc_tensor(sto, offset, sh) call order
    # above, this is presumably the offset constant, not the shape.
    # NOTE(review): body.args[2] (presumably the shape constant) is not
    # checked here — confirm its expected device and assert it if intended.
    assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev
    # alloc_tensor keeps the same device context as storage, which is on gpu
    assert body in ca and ca[body][0].value == gpu_dev
Пример #5
0
def test_typecall_invalid_num_args():
    """Kind checking a TypeCall with too few type arguments should fail.

    ``v1`` is registered with one type parameter but is called with none.
    The module holding that definition is passed to ``check_kind``.
    Previously ``mod`` was built but never used, so the checker could not
    see ``v1``'s definition and presumably failed for an unrelated reason.

    NOTE(review): no assertion is made; presumably the suite expects
    ``check_kind`` to raise (e.g. via an expected-exception decorator not
    visible here) — confirm.
    """
    mod = tvm.IRModule()
    gtv = relay.GlobalTypeVar("v1")
    tv = relay.TypeVar("tv")
    data = relay.TypeData(gtv, [tv], [])
    mod[gtv] = data
    # Pass `mod` so the only problem the checker can find is the arity mismatch.
    check_kind(relay.TypeCall(gtv, []), mod)
Пример #6
0
def test_typecall_invalid_num_args():
    """Kind checking a TypeCall with too few type arguments should fail.

    ``v1`` is registered with one type parameter but is called with none.
    The module holding that definition is passed to ``check_kind``.
    Previously ``mod`` was built but never used, so the checker could not
    see ``v1``'s definition and presumably failed for an unrelated reason.

    NOTE(review): no assertion is made; presumably the suite expects
    ``check_kind`` to raise (e.g. via an expected-exception decorator not
    visible here) — confirm.
    """
    mod = relay.Module()
    gtv = relay.GlobalTypeVar('v1')
    tv = relay.TypeVar('tv')
    data = relay.TypeData(gtv, [tv], [])
    mod[gtv] = data
    # Pass `mod` so the only problem the checker can find is the arity mismatch.
    check_kind(relay.TypeCall(gtv, []), mod)
Пример #7
0
def test_typecall_invalid_args():
    """Kind checking a TypeCall whose argument is not of kind Type should fail.

    The argument passed is the TypeData ``data`` itself rather than a type.
    The module defining ``v1`` is passed to ``check_kind`` so the callee can
    be resolved. Previously ``mod`` was built but never used.

    NOTE(review): ``v1`` declares no type parameters, so the single argument
    is also an arity mismatch. Confirm which check is meant to trigger.
    No assertion is made; presumably the suite expects ``check_kind`` to
    raise (e.g. via an expected-exception decorator not visible here).
    """
    # args must all be type kind
    mod = tvm.IRModule()
    gtv = relay.GlobalTypeVar("v1")
    data = relay.TypeData(gtv, [], [])
    mod[gtv] = data

    check_kind(relay.TypeCall(gtv, [data]), mod)
Пример #8
0
def test_type_call_alpha_equal():
    """TypeCall equality depends on the callee and on every argument, in order.

    A structurally identical call compares equal. Calls that differ in
    callee, in one argument, in argument count, or in argument order must
    all compare unequal.
    """
    h1 = relay.GlobalTypeVar("h1")
    h2 = relay.GlobalTypeVar("h2")
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")
    t3 = relay.TensorType((1, 2, 3, 4), "float32")
    t4 = relay.TensorType((), "float32")

    tc = relay.TypeCall(h1, [t1, t2, t3])
    same = relay.TypeCall(h1, [t1, t2, t3])

    different_func = relay.TypeCall(h2, [t1, t2, t3])
    different_arg = relay.TypeCall(h1, [t1, t2, t4])
    fewer_args = relay.TypeCall(h1, [t1, t2])
    more_args = relay.TypeCall(h1, [t1, t2, t3, t4])
    different_order_args = relay.TypeCall(h1, [t3, t2, t1])

    assert tc == same
    assert tc != different_func
    # Previously built but never checked: a single differing argument.
    assert tc != different_arg
    assert tc != fewer_args
    assert tc != more_args
    assert tc != different_order_args
Пример #9
0
def test_typecall_invalid_callee():
    """Kind checking a TypeCall whose callee is not an ADT handle.

    The GlobalTypeVar is created with kind ``Type`` instead of the ADT-handle
    kind a TypeCall callee needs. No assertion is made.

    NOTE(review): presumably the suite expects ``check_kind`` to raise
    (e.g. via an expected-exception decorator not visible here) — confirm.
    """
    # global type var must be an ADT handle
    bad_callee = relay.GlobalTypeVar("v1", relay.TypeKind.Type)
    call = relay.TypeCall(bad_callee, [])
    check_kind(call)
Пример #10
0
def storage_type(mod):
    """Return the ``Storage`` ADT from ``mod`` applied to no type arguments.

    ``mod`` must already define a global type var named ``Storage``
    (presumably via ``mod.import_from_std("core.rly")`` — confirm at callers).
    """
    storage_gtv = mod.get_global_type_var("Storage")
    return relay.TypeCall(storage_gtv, [])