Esempio n. 1
0
def test_serializer():
    """Check what the VM ``Serializer`` reports for a three-function module.

    The module holds ``f1`` (adds a constant), ``f2`` (subtracts a constant)
    and ``main`` (multiplies the results of calling ``f1`` and ``f2``).  The
    test then inspects the serializer's stats, globals, primitive ops and
    bytecode listing, and finally the types returned by ``serialize()``.
    """
    module = rly_module({})

    # f1(x) = x + 1.0 over a 10x10 float32 tensor.
    one = relay.const(1.0, "float32")
    add_arg = relay.var('x', shape=(10, 10), dtype='float32')
    add_fn = relay.Function([add_arg], add_arg + one)
    add_gv = relay.GlobalVar("f1")
    module[add_gv] = add_fn

    # f2(y) = y - 2.0 over a 10x10 float32 tensor.
    two = relay.const(2.0, "float32")
    sub_arg = relay.var('y', shape=(10, 10), dtype='float32')
    sub_fn = relay.Function([sub_arg], sub_arg - two)
    sub_gv = relay.GlobalVar("f2")
    module[sub_gv] = sub_fn

    # main(x1, y1) = f1(x1) * f2(y1)
    main_lhs = relay.var('x1', shape=(10, 10), dtype='float32')
    main_rhs = relay.var('y1', shape=(10, 10), dtype='float32')
    entry = relay.Function([main_lhs, main_rhs],
                           add_gv(main_lhs) * sub_gv(main_rhs))
    module["main"] = entry

    machine = create_vm(module)
    ser = serializer.Serializer(machine)

    # The stats text is expected to mention "scalar".
    summary = ser.stats
    assert "scalar" in summary

    # Exactly the three functions registered above must be listed.
    exported = ser.globals
    assert len(exported) == 3
    for fn_name in ("f1", "f2", "main"):
        assert fn_name in exported

    # One fused primitive op per arithmetic operator used in the module.
    kernels = ser.primitive_ops
    for prefix in ('fused_add', 'fused_subtract', 'fused_multiply'):
        assert any(op.startswith(prefix) for op in kernels)

    # Each function must show up in the bytecode listing with these headers
    # (the meaning of the numeric fields is not visible from this test).
    listing = ser.bytecode
    for header in ("main 5 2 5", "f1 3 1 4", "f2 3 1 4"):
        assert header in listing

    # serialize() yields the raw bytecode and the compiled library module.
    blob, library = ser.serialize()
    assert isinstance(blob, bytearray)
    assert isinstance(library, tvm.module.Module)
def test_serializer():
    """Check what a VM executable reports for a three-function module.

    The module holds ``f1`` (adds a constant), ``f2`` (subtracts a constant)
    and ``main`` (multiplies the results of calling ``f1`` and ``f2``).  The
    test inspects the executable's globals, primitive ops and bytecode
    listing, and finally the types returned by ``save()``.
    """
    module = rly_module({})

    # f1(x) = x + 1.0 over a 10x10 float32 tensor.
    one = relay.const(1.0, "float32")
    add_arg = relay.var('x', shape=(10, 10), dtype='float32')
    add_fn = relay.Function([add_arg], add_arg + one)
    add_gv = relay.GlobalVar("f1")
    module[add_gv] = add_fn

    # f2(y) = y - 2.0 over a 10x10 float32 tensor.
    two = relay.const(2.0, "float32")
    sub_arg = relay.var('y', shape=(10, 10), dtype='float32')
    sub_fn = relay.Function([sub_arg], sub_arg - two)
    sub_gv = relay.GlobalVar("f2")
    module[sub_gv] = sub_fn

    # main(x1, y1) = f1(x1) * f2(y1)
    main_lhs = relay.var('x1', shape=(10, 10), dtype='float32')
    main_rhs = relay.var('y1', shape=(10, 10), dtype='float32')
    entry = relay.Function([main_lhs, main_rhs],
                           add_gv(main_lhs) * sub_gv(main_rhs))
    module["main"] = entry

    executable = create_exec(module)

    # Exactly the three functions registered above must be listed.
    exported = executable.globals
    assert len(exported) == 3
    for fn_name in ("f1", "f2", "main"):
        assert fn_name in exported

    # One fused primitive op per arithmetic operator used in the module.
    kernels = executable.primitive_ops
    for prefix in ('fused_add', 'fused_subtract', 'fused_multiply'):
        assert any(op.startswith(prefix) for op in kernels)

    # The bytecode listing must name each function with its parameters.
    listing = executable.bytecode
    for signature in ("main(x1, y1)", "f1(x)", "f2(y)"):
        assert signature in listing

    # save() yields the raw bytecode and the compiled runtime library.
    blob, library = executable.save()
    assert isinstance(blob, bytearray)
    assert isinstance(library, tvm.runtime.Module)