Example #1
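The snippets in this collection appear to be excerpts from TVM's TE-schedule test suite. Each one assumes a common set of imports roughly like the sketch below; the exact import path of schedule_to_module is an assumption here, and helpers such as register_mem (used in Example #2) are defined in the original test files.

import tvm
import tvm.testing
from tvm import te
from tvm.driver.build_module import schedule_to_module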
def test_storage_share():
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    num_stage = 5
    B = A
    for t in range(num_stage):
        B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t)

    s = te.create_schedule(B.op)
    mod = schedule_to_module(s, [A, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    mod = tvm.tir.transform.Simplify()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    stmt = mod["main"].body

    # Verify that only one allocation remains,
    # i.e. that in-place folding works.
    num_alloc = [0]

    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 1
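Each stage here writes element (i, j) using only element (i, j) of the previous stage, so StorageRewrite can fold every intermediate stage in place into a single buffer, which is why exactly one allocation remains.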
Example #2
def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024):
    # Test Buffer
    register_mem(scope_tb, max_bits)
    m = 10
    A = te.placeholder((m, ), name="A")
    C = te.placeholder((m, ), name="C")
    D = te.placeholder((m, ), name="D")
    A0 = te.compute((m, ), lambda i: A[i] + C[i], name="A0")
    A1 = te.compute((m, ), lambda i: D[i] * D[i], name="A1")
    A2 = te.compute((m, ), lambda i: A0[i] + A1[i], name="A2")
    B = te.compute((m, ), lambda i: A2[i], name="B")
    s = te.create_schedule(B.op)
    A0L = s.cache_read(A0, scope_tb, [A2])
    A1L = s.cache_read(A1, scope_tb, [A2])
    A2L = s.cache_read(A2, scope_tb, [B])
    mod = schedule_to_module(s, [A, B, C, D])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    mod = tvm.tir.transform.Simplify()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    stmt = mod["main"].body

    # Verify that only two allocations remain,
    # i.e. that in-place folding works.
    num_alloc = [0]

    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 2
Example #3
def test_storage_combine_with_vectorization():
    n = 1024
    A = te.placeholder((n, ), name="A")
    B = te.placeholder((n, ), name="B")
    C = te.compute((n, ), lambda i: A[i] + B[i], name="C")
    s = te.create_schedule(C.op)
    AA = s.cache_read(A, "global:tag", readers=[C])
    BB = s.cache_read(B, "global:tag", readers=[C])
    CC = s.cache_write(C, "global:tag")
    s[CC].vectorize(s[CC].op.axis[0])
    mod = schedule_to_module(s, [A, B, C])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    mod = tvm.tir.transform.VectorizeLoop()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    stmt = mod["main"].body
    num_alloc = [0]

    def verify(v):
        # find add op
        if (isinstance(v, tvm.tir.Add) and isinstance(v.a, tvm.tir.Load)
                and isinstance(v.b, tvm.tir.Load)):
            lhs_ramp = v.a.index
            rhs_ramp = v.b.index
            # These two ramp loads should not overlap.
            assert lhs_ramp.lanes == n
            assert rhs_ramp.lanes == n
            assert lhs_ramp.base >= rhs_ramp.base + n or rhs_ramp.base >= lhs_ramp.base + n
        elif isinstance(v, tvm.tir.Allocate):
            num_alloc[0] += 1

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 1
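Because AA, BB, and CC all live in the tagged "global:tag" scope, StorageRewrite combines them into a single allocation; the ramp-base check above ensures that the AA and BB regions do not overlap inside that combined buffer.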
Example #4
def test_copy_pad_split():
    m = 4 * 3
    A = te.placeholder((m, ), name="A")
    Apad = te.compute((m + 2, ), lambda i: tvm.tir.if_then_else(
        tvm.tir.all(i >= 1, i <= m), A[i - 1], 0.0), "Apad")
    B = te.compute((m, ), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)
    s[Apad].compute_at(s[B], xo)
    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")

    mod = schedule_to_module(s, [A, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
    mod = tvm.tir.transform.Simplify()(mod._move())

    def cb(src, dst, pad_before, pad_after, pad_value):
        assert dst.elem_offset.value == 0
        tvm.testing.assert_prim_expr_equal(src.elem_offset,
                                           tvm.te.max(xo * 4, 1) - 1)

        rpad_before = tvm.te.max(1 - xo * 4, 0)
        rpad_after = tvm.te.max(xo * 4 - 7, 0)
        tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before)
        tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after)
        tvm.testing.assert_prim_expr_equal(src.shape[0],
                                           6 - rpad_before - rpad_after)
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
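For the tile selected by xo, B needs Apad[xo * 4 .. xo * 4 + 5], which corresponds to A[xo * 4 - 1 .. xo * 4 + 4] clamped to A's 12 elements. The callback therefore expects the copy to start at max(xo * 4, 1) - 1, to pad one element in front only for the first tile (max(1 - xo * 4, 0)), to pad one element at the back only for the last tile (max(xo * 4 - 7, 0)), and to shrink the copied extent from 6 by the amount of padding.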
Example #5
def test_inplace_rule():
    m = 10
    A = te.placeholder((m, ), name="A")
    A0 = te.compute((m, ), lambda i: A[i], name="A0")
    A1 = te.compute((m, ), lambda i: A[i] + 1, name="A1")
    AA = te.compute((m, ), lambda i: A0[i] + A1[i] + A1[0], name="AA")
    B = te.compute((m, ), lambda i: AA[i] + 1, name="B")
    s = te.create_schedule(B.op)
    mod = schedule_to_module(s, [A, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    mod = tvm.tir.transform.Simplify()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    stmt = mod["main"].body

    # Verify that only two allocations remain,
    # i.e. that in-place folding works.
    num_alloc = [0]

    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 2
Example #6
def test_copy_pad():
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute(
        (m + 2, l),
        lambda i, j: tvm.tir.if_then_else(
            tvm.tir.all(i >= 1, i < m + 1), A[i - 1, j], 1.0),
        name="B",
    )
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    mod = schedule_to_module(s, [A, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    def cb(src, dst, pad_before, pad_after, pad_value):
        tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
Example #7
def test_storage_share_gpu():
    m = te.var("m")
    A = [te.placeholder((m, ), name="A")]
    num_stage = 5
    for t in range(num_stage):
        A.append(
            te.compute((m, ), lambda i: A[-1][i] + (t + 1), name="A%d_s" % t))
        A.append(te.compute((m, ), lambda i: A[-1][i], name="A%d" % t))
    s = te.create_schedule(A[-1].op)
    for t in range(num_stage):
        x = A[2 * t + 2].op.axis[0]
        bx, tx = s[A[2 * t + 2]].split(x, factor=32)
        s[A[2 * t + 2]].bind(bx, te.thread_axis("blockIdx.x"))
        s[A[2 * t + 2]].bind(tx, te.thread_axis("threadIdx.x"))
        s[A[2 * t + 1]].compute_at(s[A[2 * t + 2]], tx)
        s[A[2 * t + 1]].set_scope("shared")

    mod = schedule_to_module(s, [A[0], A[-1]])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    stmt = mod["main"].body

    alloc_stats = {"global": 0, "shared": 0}

    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            scope = n.buffer_var.type_annotation.storage_scope
            alloc_stats[scope] += 1

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert alloc_stats["global"] == 2
    assert alloc_stats["shared"] == num_stage
Example #8
def test_storage_combine():
    n = 8
    A = te.placeholder((4, ), name="A")
    num_stage = 5
    B = A
    stages = []
    for t in range(num_stage):
        B = te.compute((n, ), lambda i: B[i] + B[0] + (t + 1), name="A%d" % t)
        stages.append(B)

    s = te.create_schedule(B.op)
    for S in stages[:-1]:
        s[S].set_scope("global:tag")

    mod = schedule_to_module(s, [A, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    mod = tvm.tir.transform.Simplify()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    stmt = mod["main"].body

    num_alloc = [0]

    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 16

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 1
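The four intermediate stages all use the tagged "global:tag" scope, so StorageRewrite plans them into one combined allocation; since only two 8-element stages are live at any time, the combined extent is 16, which is what the verifier checks.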
Example #9
    def test_internal_reshape(self, target, dev, n_items, reordered_shape,
                              dtype, fphysical_layout):
        # The reshaping of the buffer gets flattened away in
        # StorageFlatten, so we test the behavior by running only
        # ApplyLayoutTransforms.
        logical_shape = (n_items, )
        A = te.placeholder(logical_shape, name="A", dtype=dtype)
        B = te.compute(shape=logical_shape, fcompute=lambda i: A[i], name="B")
        C = te.compute(shape=logical_shape, fcompute=lambda i: B[i], name="C")

        s = te.create_schedule(C.op)
        s[B].transform_layout(fphysical_layout)

        mod = schedule_to_module(s, [A, C])
        body = mod["main"].body

        def walk_buffer_interactions(stmt, callback):
            buffer_classes = [
                tvm.tir.BufferLoad,
                tvm.tir.BufferStore,
                tvm.tir.BufferRealize,
            ]

            def inner(node):
                if (type(node) in buffer_classes) and node.buffer.name == "B":
                    callback(node)

            post_order_visit(stmt, inner)

        # All references to the buffer are the same object
        def check_references():
            buffer_object = None

            def inner(node):
                nonlocal buffer_object
                if buffer_object is None:
                    buffer_object = node.buffer
                else:
                    assert node.buffer.same_as(buffer_object)

            return inner

        # The buffer has the expected shape.
        def check_shape(expected_shape):
            def inner(node):
                assert tuple(node.buffer.shape) == expected_shape

            return inner

        # Before the transform, the buffer should be in the logical shape.
        walk_buffer_interactions(body, check_references())
        walk_buffer_interactions(body, check_shape(logical_shape))

        mod = tvm.tir.transform.ApplyLayoutTransforms()(mod)
        body = mod["main"].body

        # After the transform, the buffer should be in the physical shape.
        walk_buffer_interactions(body, check_references())
        walk_buffer_interactions(body, check_shape(reordered_shape))
Example #10
def test_schedule0():
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    s = te.create_schedule(A1.op)

    mod = schedule_to_module(s, [A, A1])
    assert isinstance(mod["main"], tvm.tir.PrimFunc)
Example #11
def run_passes(sch, args):
    mod = schedule_to_module(sch, args)
    return tvm.transform.Sequential([
        tvm.tir.transform.StorageFlatten(64),
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.VectorizeLoop(),
        tvm.tir.transform.StorageRewrite(),
        tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
    ])(mod)
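A minimal usage sketch for this helper, assuming the common imports noted near the top of this collection; the shapes and tensor names below are hypothetical:

n = 128
A = te.placeholder((n, ), name="A")
B = te.compute((n, ), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
mod = run_passes(s, [A, B])
print(mod["main"].body)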
Example #12
def run_passes(sch, args):
    mod = schedule_to_module(sch, args)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
    return tvm.transform.Sequential([
        tvm.tir.transform.StorageFlatten(64),
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.StorageRewrite(),
        tvm.tir.transform.LowerThreadAllreduce(),
    ])(mod)
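In this variant, target is a free variable captured from the enclosing test, so it has to be bound before the helper is called. A minimal sketch, reusing the schedule s and tensors A and B from the sketch after Example #11 (the llvm target choice is an assumption):

target = tvm.target.Target("llvm")
mod = run_passes(s, [A, B])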
Example #13
def test_schedule1():
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")

    s = te.create_schedule(A1.op)
    xo, xi = s[A1].split(A1.op.axis[0], 8)
    s[A1].pragma(xo, "auto_unroll_max_step", 10)

    mod = schedule_to_module(s, [A, A1])
    assert isinstance(mod["main"], tvm.tir.PrimFunc)
Example #14
def test_schedule2():
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")

    s = te.create_schedule(A2.op)
    xo, xi = s[A2].split(A2.op.axis[0], 8)
    s[A1].compute_at(s[A2], xo)

    mod = schedule_to_module(s, [A, A2])
    assert isinstance(mod["main"], tvm.tir.PrimFunc)
Example #15
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Do lower while keeping all axes in IR
    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
    """
    sch = sch.normalize()
    # Phase 0
    context = tvm.transform.PassContext(
        config={"tir.debug_keep_trivial_loop": True})
    with context:
        mod = build_module.schedule_to_module(sch, args, binds=binds)

    mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
    mod = tvm.tir.transform.Simplify()(mod._move())
    assert simple_mode
    return mod["main"].body
Example #16
def test_flatten2():
    m = te.size_var("m")
    l = te.size_var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")

    s = te.create_schedule(A2.op)
    xo, xi = s[A2].split(A2.op.axis[0], 8)
    s[A1].compute_at(s[A2], xo)
    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
    A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2")

    mod = schedule_to_module(s, [Ab, A2b], binds={A: Ab, A2: A2b})
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
Example #17
def test_single_point_test():
    A = te.placeholder((1, ), name="A")
    B = te.compute((1, ), lambda i: A[i], name="B")
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    mod = schedule_to_module(s, [A, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    def cb(src, dst, pad_before, pad_after, pad_value):
        tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
        tvm.testing.assert_prim_expr_equal(dst.elem_offset, 0)
        tvm.testing.assert_prim_expr_equal(src.strides[0], 1)
        tvm.testing.assert_prim_expr_equal(dst.strides[0], 1)
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
Example #18
def lower_sch(sch, args, target_bits):
    binds = {}
    arg_list = []
    for x in args:
        if isinstance(x, te.tensor.Tensor):
            buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
            assert x not in binds
            binds[x] = buf
            arg_list.append(buf)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    sch = sch.normalize()

    mod = schedule_to_module(sch, args)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
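A minimal usage sketch (hypothetical shapes; target_bits=32 is an arbitrary choice):

n = te.var("n")
A = te.placeholder((n, ), name="A")
B = te.compute((n, ), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
body = lower_sch(s, [A, B], target_bits=32)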
Example #19
def test_flatten_storage_align():
    m = 8
    l = 16
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")

    s = te.create_schedule(A2.op)
    s[A1].storage_align(A1.op.axis[0], 2, 1)

    mod = schedule_to_module(s, [A, A2])
    mod = tvm.transform.Sequential(
        [tvm.tir.transform.StorageFlatten(64),
         tvm.tir.transform.Simplify()])(mod)

    stmt = mod["main"].body
    assert stmt.extents[0].value == 17 * 8
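Here storage_align(A1.op.axis[0], 2, 1) forces the stride of A1's first axis to satisfy stride % 2 == 1, so each 16-element row of A1 is padded to 17 elements; with 8 rows the flattened allocation therefore has extent 17 * 8.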
Example #20
def test_makeapi():
    """Not yet working, mock design"""
    n = te.size_var("n")
    A = te.placeholder((n, ), name="A")
    B = te.placeholder((n, ), name="B")
    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
    s = te.create_schedule(C.op)

    mod = schedule_to_module(s, [n, A, B, C])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    mod = tvm.tir.transform.Apply(
        lambda f: f.with_attr({
            "target": tvm.target.Target("llvm"),
            "global_symbol": "main",
        }))(mod)

    num_unpacked_args = 2
    f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
    assert len(f.params) == 8
Example #21
def lower_ethosu(sch, args, const_dict, name="main"):
    """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.

    The resulting TIR module will contain a single function
    that consists of a sequence of tir.call_extern to NPU
    operations.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to be lowered.
    args : Union[list of tvm.te.Tensor, TEGraph]
        The input/output tensors.
    const_dict : dict of int to numpy.ndarray
        The constant dictionary.
    name : str, optional
        The name of the lowered primitive function.

    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.
    const_dict : dict of int to numpy.ndarray
        The modified constant dictionary.

    """
    if not isinstance(args, list):
        args = list(args.inputs) + list(args.outputs)
    # config setup
    curr_pass_ctx = tvm.ir.transform.PassContext.current()
    curr_cfg = dict()
    for key, value in curr_pass_ctx.config.items():
        curr_cfg[key] = value
    tir_compiler_cfg = {
        "tir.LoopPartition": {
            "partition_const_loop": True,
            "no_unroll_loop_with_extent_one": True,
        },
        "tir.UnrollLoop": {
            "auto_max_depth": -1
        },
        "tir.noalias": True,
        "tir.debug_keep_trivial_loop": True,
    }
    # Merge two configs
    curr_cfg = {**curr_cfg, **tir_compiler_cfg}

    sch = sch.normalize()

    with tvm.transform.PassContext(config=curr_cfg):
        mod = schedule_to_module(sch, args, name)

        mod = tvm.tir.transform.Simplify()(mod)
        mod = ethosu_passes.RemoveConcatenates()(mod)
        mod = tvm.tir.transform.InjectRollingBuffer()(mod)
        mod = tvm.tir.transform.StorageFlatten(64)(mod)
        mod = tvm.tir.transform.UnrollLoop()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.LoopPartition()(mod)
        mod = ethosu_passes.RemoveZeroStores()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.ReplaceOperators()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
        mod = ethosu_passes.HoistAllocates()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        #  MergeConstant pass currently does not support striped schedules.
        #  It requires further investigation.
        if not util.is_striping_enabled():
            mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
        mod = ethosu_passes.CopyComputeReordering()(mod)

        # When striping is enabled and storage_rewrite is not run, striping
        # results in incorrect code generation. This needs further
        # investigation. Until that is fixed, the disable_storage_rewrite
        # user directive is overridden when striping is enabled.
        disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite",
                                               False)
        if not disable_storage_rewrite or util.is_striping_enabled():
            mod = tvm.tir.transform.StorageRewrite()(mod)

        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.AnnotateAllocates()(mod)
        mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(
            const_dict)(mod)
    return mod, const_dict
Example #22
def lower_ethosu(sch, args, const_dict, name="main"):
    """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.

    The resulting TIR module will contain a single function
    that consists of a sequence of tir.call_extern to NPU
    operations.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to be lowered.
    args : Union[list of tvm.te.Tensor, TEGraph]
        The input/output tensors.
    const_dict : dict of int to numpy.ndarray
        The constant dictionary.
    name : str, optional
        The name of the lowered primitive function.

    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.
    const_dict : dict of int to numpy.ndarray
        The modified constant dictionary.

    """
    if not isinstance(args, list):
        args = list(args.inputs) + list(args.outputs)
    # config setup
    curr_pass_ctx = tvm.ir.transform.PassContext.current()
    curr_cfg = dict()
    for key, value in curr_pass_ctx.config.items():
        curr_cfg[key] = value
    tir_compiler_cfg = {
        "tir.LoopPartition": {
            "partition_const_loop": True,
            "no_unroll_loop_with_extent_one": True,
        },
        "tir.UnrollLoop": {
            "auto_max_depth": -1
        },
        "tir.noalias": True,
        "tir.debug_keep_trivial_loop": True,
    }
    # Merge two configs
    curr_cfg = {**curr_cfg, **tir_compiler_cfg}

    sch = sch.normalize()

    with tvm.transform.PassContext(config=curr_cfg):
        mod = schedule_to_module(sch, args, name)

        mod = tvm.tir.transform.Simplify()(mod)
        mod = ethosu_passes.RemoveConcatenates()(mod)
        mod = tvm.tir.transform.StorageFlatten(64)(mod)
        mod = tvm.tir.transform.UnrollLoop()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.LoopPartition()(mod)
        mod = ethosu_passes.RemoveZeroStores()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.ReplaceOperators()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
        mod = tvm.tir.transform.StorageRewrite()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.AnnotateAllocates()(mod)
    return mod, const_dict
Example #23
def test_inplace_rule3():
    # Test Buffer
    scope_tb = "local_TB3"
    max_bits = 1024 * 1024 * 1024

    register_mem(scope_tb, max_bits)
    m = 10
    B0 = te.placeholder((m, ), name="B0")
    B1 = te.placeholder((m, ), name="B1")
    B2 = te.placeholder((m, ), name="B2")
    B3 = te.placeholder((m, ), name="B3")
    B4 = te.placeholder((m, ), name="B4")
    B5 = te.placeholder((m, ), name="B5")

    B6 = te.compute((m, ), lambda i: B1[i] * B5[i], name="B6")
    B7 = te.compute((m, ), lambda i: B2[i] * B4[i], name="B7")
    B8 = te.compute((m, ), lambda i: B6[i] - B7[i], name="B8")

    B9 = te.compute((m, ), lambda i: B2[i] * B3[i], name="B9")
    B10 = te.compute((m, ), lambda i: B0[i] * B5[i], name="B10")
    B11 = te.compute((m, ), lambda i: B9[i] - B10[i], name="B11")

    B12 = te.compute((m, ), lambda i: B0[i] * B4[i], name="B12")
    B13 = te.compute((m, ), lambda i: B1[i] * B3[i], name="B13")
    B14 = te.compute((m, ), lambda i: B12[i] - B13[i], name="B14")

    B = te.compute((m, ), lambda i: B8[i] * B11[i] + B14[i], name="B")
    s = te.create_schedule(B.op)

    B1L = s.cache_read(B1, scope_tb, [B6, B13])
    B5L = s.cache_read(B5, scope_tb, [B6, B10])
    B2L = s.cache_read(B2, scope_tb, [B7, B9])
    B4L = s.cache_read(B4, scope_tb, [B7, B12])
    B3L = s.cache_read(B3, scope_tb, [B9, B13])
    B0L = s.cache_read(B0, scope_tb, [B10, B12])

    B8L = s.cache_write(B8, scope_tb)
    B11L = s.cache_write(B11, scope_tb)
    B14L = s.cache_write(B14, scope_tb)
    B6L = s.cache_write(B6, scope_tb)
    B7L = s.cache_write(B7, scope_tb)
    B9L = s.cache_write(B9, scope_tb)
    B10L = s.cache_write(B10, scope_tb)
    B12L = s.cache_write(B12, scope_tb)
    B13L = s.cache_write(B13, scope_tb)

    s[B12].compute_inline()
    s[B13].compute_inline()
    s[B8].compute_inline()
    s[B11].compute_inline()
    s[B14].compute_inline()
    s[B6].compute_inline()
    s[B7].compute_inline()
    s[B9].compute_inline()
    s[B10].compute_inline()

    s = s.normalize()
    mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    mod = tvm.tir.transform.Simplify()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    stmt = mod["main"].body

    # Verify that in-place folding works:
    # every remaining allocation has extent 70.
    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            assert n.extents[0].value == 70

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)