def test_split_int64_extent_with_mixed_factors():
    """Split an int64-extent loop using factors of mixed dtypes.

    The loop extent is the int64 constant 384. The factors are an int64 ``1``
    and an int32 ``512``; the inner factor is larger than the extent. The test
    has no explicit assertion: it passes as long as ``split`` does not raise
    on a schedule created with ``debug_mask="all"``.
    """

    def _make_func():
        extent = te.const(384, "int64")
        a_buf = te.placeholder((extent,), name="A", dtype="float32")
        # Keep the lambda parameter name ``i``: te.compute may use it to name
        # the iteration variable.
        b_buf = te.compute((extent,), lambda i: a_buf[i] + 1, name="B")
        return te.create_prim_func([a_buf, b_buf])

    schedule = tir.Schedule(_make_func(), debug_mask="all")
    block_b = schedule.get_block("B")
    (loop,) = schedule.get_loops(block_b)
    # Deliberately mismatched factor dtypes: int64 outer, int32 inner.
    outer_factor = te.const(1, "int64")
    inner_factor = te.const(512, "int32")
    schedule.split(loop, factors=[outer_factor, inner_factor])
def _create_prim_func():
    """Return a TIR function for ``B[i] = A[i] + 1`` over a length-12 int64 extent.

    ``A`` is a float32 placeholder, and ``B`` is the elementwise compute
    named ``"B"``.
    """
    shape = (te.const(12, "int64"),)
    a_buf = te.placeholder(shape, name="A", dtype="float32")
    # The lambda parameter name is kept as ``i`` because te.compute may use it
    # to name the iteration variable.
    b_buf = te.compute(shape, lambda i: a_buf[i] + 1, name="B")
    return te.create_prim_func([a_buf, b_buf])
def _create_prim_func():
    """Return a TIR function for ``B[i, j] = A[i, j] + 1`` over a 16x32 shape.

    The first extent is an int32 constant and the second is an int64
    constant, so the two loop extents have mixed dtypes. ``A`` is an int32
    placeholder.
    """
    rows = te.const(16, "int32")
    cols = te.const(32, "int64")
    shape = (rows, cols)
    a_buf = te.placeholder(shape, name="A", dtype="int32")
    # Keep the lambda parameter names ``i``/``j``: te.compute may use them to
    # name the iteration variables.
    b_buf = te.compute(shape, lambda i, j: a_buf[i, j] + 1, name="B")
    return te.create_prim_func([a_buf, b_buf])