Example #1
def test_estimate_flop():
    N = 512
    A, B, C = matmul_auto_scheduler_test(N, N, N)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    assert abs(dag.flop_ct - 2 * N**3) < 0.5

    D = topi.nn.relu(C)
    dag = auto_scheduler.ComputeDAG([A, B, D])
    assert abs(dag.flop_ct - (2 * N**3 + N * N)) < 0.5

    # should not count the comparison operations in padding
    E = topi.nn.pad(C, [1, 1])
    dag = auto_scheduler.ComputeDAG([A, B, E])
    assert abs(dag.flop_ct - 2 * N**3) < 0.5

    F = te.compute((N, N),
                   lambda i, j: E[i, j],
                   name="F",
                   attrs={"FLOP": 1234})
    dag = auto_scheduler.ComputeDAG([A, B, F])
    assert abs(dag.flop_ct - (2 * N**3 + 1234)) < 0.5

    A = te.placeholder((N, N), dtype="float32", name="A")
    F = te.compute((N, N),
                   lambda i, j: te.if_then_else(A[i, j] > 0, A[i, j], 0))
    dag = auto_scheduler.ComputeDAG([A, F])
    assert abs(dag.flop_ct - N**2) < 0.5
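matmul_auto_scheduler_test is referenced throughout these examples but never defined in them. Based on the compute expression that appears verbatim in the later tests (te.sum(A[i][k] * B[k][j])), it is presumably a registered workload along these lines; this is an assumed sketch, not the exact upstream helper:

import tvm
from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def matmul_auto_scheduler_test(N, M, K):
    # Assumed workload: an (N, K) x (K, M) matmul matching the compute used in these tests.
    A = te.placeholder((N, K), name="A")
    B = te.placeholder((K, M), name="B")
    k = te.reduce_axis((0, K), name="k")
    C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
    return [A, B, C]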
def test_estimate_flop():
    N = 512
    A, B, C = matmul_auto_scheduler_test(N, N, N)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    assert abs(dag.flop_ct - 2 * N**3) < 0.5

    D = topi.nn.relu(C)
    dag = auto_scheduler.ComputeDAG([A, B, D])
    assert abs(dag.flop_ct - 2 * N**3 - N * N) < 0.5

    # should not count the comparison operations in padding
    D = topi.nn.pad(C, [1, 1])
    dag = auto_scheduler.ComputeDAG([A, B, D])
    assert abs(dag.flop_ct - 2 * N**3) < 0.5
def generate_sketches(workload_func, args, target, print_for_debug=False):
    workload_key = auto_scheduler.make_workload_key(workload_func, args)
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key,
                                     tvm.target.create(target))
    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    return policy.generate_sketches(print_for_debug)
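A minimal sketch of how this helper might be invoked, assuming matmul_auto_scheduler_test is registered as sketched above and the same imports as the surrounding snippets:

sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "llvm")
assert len(sketches) > 0  # the sketch policy should emit at least one tiled sketch for a matmul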
Example #4
def search_common(workload=matmul_auto_scheduler_test,
                  target="llvm",
                  search_policy='empty',
                  seed=random.randint(1, 1 << 30),
                  runner='local',
                  cost_model=auto_scheduler.RandomModel(),
                  num_measure_trials=2,
                  init_search_callbacks=None):
    print("Test %s schedule search with the default search policy" % (target))

    random.seed(seed)
    N = 128
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        init_search_callbacks = init_search_callbacks or []
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == 'empty':
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == 'sketch':
            search_policy = auto_scheduler.SketchPolicy(
                task, init_search_callbacks=init_search_callbacks)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            runner=runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        sch, args = auto_scheduler.auto_schedule(task, search_policy,
                                                 tuning_options)
        inp, res = auto_scheduler.load_best(log_file, workload_key, target)

        print("==== Python Code ====")
        print(dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(),
                                        np.dot(a.asnumpy(), b.asnumpy()),
                                        rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as exc:
            raise Exception("Error encountered with seed: %d" % (seed)) from exc
    print()
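For reference, a minimal sketch of how search_common might be driven; the policy names match the branches above, and nothing beyond the defaults defined in the snippet is assumed:

if __name__ == "__main__":
    search_common()                        # EmptyPolicy on the default llvm target
    search_common(search_policy='sketch')  # SketchPolicy with the PreloadMeasuredStates callback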
Example #5
def test_split_fuse_reorder():
    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()
    i, j, k = s0[C].iters

    assert i.range.extent == 512

    io, ii = s0.split(C, i, [16])
    assert s0[C].iters[0] == io
    assert s0[C].iters[1] == ii
    assert io.range.extent == 32
    assert ii.range.extent == 16

    jo, ji = s0.split(C, j, [8])
    assert jo.range.extent == 64
    assert ji.range.extent == 8

    s0.reorder(C, [io, jo, k, ji, ii])
    assert s0[C].iters[2].range.extent == 512

    fused_it = s0.fuse(C, [io, jo])
    assert fused_it.range.extent == 2048

    s1 = dag.get_init_state()
    i, j, _ = s1[C].iters
    i1, i2, i3 = s1.split(C, i, [8, 2])
    j1, j2, j3 = s1.split(C, j, [32, 8], False)
    assert s1[C].iters[0].range.extent == 32
    assert s1[C].iters[1].range.extent == 8
    assert s1[C].iters[2].range.extent == 2
    assert s1[C].iters[3].range.extent == 32
    assert s1[C].iters[4].range.extent == 8
    assert s1[C].iters[5].range.extent == 2
def test_record_compute_at_root_inline_cache_read_write():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    AA = topi.nn.relu(A)
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512),
                   lambda i, j: te.sum(AA[i][k] * B[k][j], axis=[k]),
                   name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Cache Write
    C_shared = s.cache_write(C, "shared")
    # Compute At
    s.compute_at(C_shared, C, s[C].iters[0])
    # Cache Read
    B_global = s.cache_read(B, "global", [C_shared])
    s.compute_at(B_global, C_shared, s[C_shared].iters[2])
    # Compute Inline
    s.compute_inline(AA)
    # Compute Root
    s.compute_root(C_shared)

    record_common(dag, s)
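record_common is not defined in these snippets. Judging from the save/load round trip in test_record (Example #9 below), a plausible sketch is the following; it assumes the tvm, auto_scheduler, and tempfile imports shared by these test files:

def record_common(dag, s):
    # Assumed helper: round-trip a (MeasureInput, MeasureResult) pair through a log
    # file and check that the recorded state matches the original one.
    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(compute_dag=dag, workload_key="test", target=target)

    inp = auto_scheduler.measure.MeasureInput(task, s)
    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    with tempfile.NamedTemporaryFile() as fp:
        auto_scheduler.save_records(fp.name, [inp], [res])
        inputs, _ = auto_scheduler.RecordReader(fp.name).read_lines()
        assert len(inputs) == 1

        s1 = dag.infer_bound_from_state(s)
        s2 = dag.infer_bound_from_state(inputs[0].state)
        assert s1 == s2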
def test_record_follow_split_follow_fused_split():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512),
                   lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
                   name="C")
    D = topi.nn.relu(C)
    E = topi.nn.relu(D)

    dag = auto_scheduler.ComputeDAG([A, B, E])
    s = dag.get_init_state()

    # Follow Split
    s.split(C, s[C].iters[0], [4, 2, 8, 4], True)
    split_step0 = len(s.transform_steps) - 1
    s.follow_split(C, s[C].iters[5], split_step0, 4)
    # Follow Fused Split
    its0 = s.split(E, s[E].iters[0], [4, 2, 8, 4], True)
    split_step1 = len(s.transform_steps) - 1
    its1 = s.split(E, s[E].iters[5], [2, 4, 2, 4], True)
    split_step2 = len(s.transform_steps) - 1
    its = []
    for i0, i1 in zip(its0, its1):
        its.append(i0)
        its.append(i1)
    for i in range(0, 5):
        s.fuse(E, [s[E].iters[i], s[E].iters[i + 1]])
    s.follow_fused_split(D, s[D].iters[0], [split_step1, split_step2], 2, True)

    record_common(dag, s)
Example #8
def test_record_split_reorder_fuse_annotation():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Split
    its0 = s.split(C, s[C].iters[0], [4, 8, 8])
    its1 = s.split(C, s[C].iters[4], [8, 4, 4])
    # Reorder
    s.reorder(
        C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]
    )
    # Fuse
    s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
    # Parallel
    s.parallel(C, s[C].iters[0])
    # Thread bind (blockIdx & threadIdx are GPU concepts; they are used here only for record testing)
    s.bind(C, s[C].iters[1], "blockIdx.x")
    s.bind(C, s[C].iters[2], "threadIdx.z")
    s.bind(C, s[C].iters[3], "vthread")
    # Unroll
    s.unroll(C, s[C].iters[4])
    # Vectorize
    s.vectorize(C, s[C].iters[6])

    record_common(dag, s)
Example #9
def test_record():
    if not tvm.runtime.enabled("llvm"):
        return

    A = te.placeholder((512, 512), name='A')
    B = te.placeholder((512, 512), name='B')
    k = te.reduce_axis((0, 512), name='k')
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
    D = topi.nn.relu(C)
    k = te.reduce_axis((0, 512), name='k')
    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
    F = topi.nn.relu(E)

    dag = auto_scheduler.ComputeDAG([A, B, F])
    s = dag.get_init_state()

    # Split
    its0 = s.split(C, s[C].iters[0], [4, 8, 8])
    its1 = s.split(C, s[C].iters[4], [8, 4, 4])
    # Reorder
    s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8],
                  its1[3]])
    # Fuse
    s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
    # Compute at
    s.split(F, s[F].iters[0], [2])
    s.compute_at(E, F, s[F].iters[0])
    # Compute inline
    s.compute_inline(D)
    # Compute root
    s.compute_root(D)
    # Parallel
    s.parallel(C, s[C].iters[0])
    # Thread bind (blockIdx & threadIdx are GPU concepts; they are used here only for record testing)
    s.bind(C, s[C].iters[1], "blockIdx.x")
    s.bind(C, s[C].iters[2], "threadIdx.z")
    s.bind(C, s[C].iters[3], "vthread")
    # Unroll
    s.unroll(C, s[C].iters[4])
    # Vectorize
    s.vectorize(C, s[C].iters[6])

    target = tvm.target.create("llvm")
    task = auto_scheduler.SearchTask(dag, "test", target)

    inp = auto_scheduler.measure.MeasureInput(task, s)
    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    with tempfile.NamedTemporaryFile() as fp:
        auto_scheduler.save_records(fp.name, [inp], [res])

        log_reader = auto_scheduler.RecordReader(fp.name)
        inputs, results = log_reader.read_lines()
        assert len(inputs) == 1

        s1 = dag.infer_bound_from_state(s)
        s2 = dag.infer_bound_from_state(inputs[0].state)

        assert s1 == s2
        assert not (s1 == dag.get_init_state())
Example #10
def test_cpu_matmul():
    dag = auto_scheduler.ComputeDAG(matmul_auto_scheduler_test(512, 512, 512))
    s = dag.get_init_state()
    C = s.stage_ops[2]

    i, j, k = s[C].iters
    io, ii = s.split(C, i, [16])
    jo, ji = s.split(C, j, [8])
    s.reorder(C, [io, jo, k, ji, ii])
    s.vectorize(C, ji)
    s.parallel(C, io)
    s.parallel(C, jo)
    s.unroll(C, k)

    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(compute_dag=dag, workload_key="test", target=target)
    names = auto_scheduler.feature.get_per_store_feature_names()
    fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[0]

    stage_0 = fea[0]
    assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names))
    fea_dict = {}
    for name, value in zip(names, stage_0):
        fea_dict[name] = value

    for name in ["B0", "B1", "B2"]:
        if fequal(fea_dict[name + ".acc_type.kReadWrite"], 1.0):
            c_name = name
        if fequal(fea_dict[name + ".acc_type.kRead"], 1.0):
            if fequal(fea_dict[name + ".stride"], 0.0):
                b_name = name
            else:
                a_name = name

    """
    lowered IR:
    
    Placeholder: A, B
    parallel i.0 (0,32)
      parallel j.0 (0,64)
        unroll k (0,512)
          vectorize j.1 (0,8)
            for i.1 (0,16)
              C[...] = A[...] * B[...]
    """

    # check touched memory in bytes, touched unique memory in bytes, reuse distance, etc.
    assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1))
    assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1))
    assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1))
    assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1))
    assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1))

    # check annotations
    assert fequal(fea_dict["unroll_num"], math.log2(1 + 1))
    # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0)
    assert fequal(fea_dict["vec_num"], math.log2(1 + 1))
    assert fequal(fea_dict["parallel_num"], math.log2(2 + 1))
    assert fequal(fea_dict["parallel_prod"], math.log2((512 * 512 / 16 / 8) + 1))
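fequal is another helper that is not shown here; it presumably performs an approximate floating-point comparison, roughly as follows (an assumed definition):

def fequal(a, b, eps=1e-6):
    # Approximate float equality used by the feature assertions in these tests.
    return abs(a - b) < eps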
def test_split_fuse_reorder_annotation():
    A, B, C = matmul_auto_scheduler_test(N=512, M=512, K=512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()
    i, j, k = s0[C].iters

    assert i.range.extent == 512

    io, ii = s0.split(C, i, [16])
    assert s0[C].iters[0] == io
    assert s0[C].iters[1] == ii
    assert io.range.extent == 32
    assert ii.range.extent == 16

    jo, ji = s0.split(C, j, [8])
    assert jo.range.extent == 64
    assert ji.range.extent == 8

    s0.reorder(C, [io, jo, k, ji, ii])
    assert s0[C].iters[2].range.extent == 512

    fused_it = s0.fuse(C, [io, jo])
    assert fused_it.range.extent == 2048

    s1 = dag.get_init_state()
    i, j, _ = s1[C].iters
    i1, i2, i3 = s1.split(C, i, [8, 2])
    j1, j2, j3 = s1.split(C, j, [32, 8], False)
    assert s1[C].iters[0].range.extent == 32
    assert s1[C].iters[1].range.extent == 8
    assert s1[C].iters[2].range.extent == 2
    assert s1[C].iters[3].range.extent == 32
    assert s1[C].iters[4].range.extent == 8
    assert s1[C].iters[5].range.extent == 2

    res = s1.bind(C, i1, "blockIdx.x")
    assert res == s1[C].iters[0]
    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["blockIdx.x"]

    res = s1.bind(C, i2, "vthread")
    assert res == s1[C].iters[1]
    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vthread"]

    res = s1.bind(C, i3, "threadIdx.y")
    assert res == s1[C].iters[2]
    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["threadIdx.y"]

    res = s1.parallel(C, j1)
    assert res == s1[C].iters[3]
    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["parallel"]

    res = s1.unroll(C, j2)
    assert res == s1[C].iters[4]
    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["unroll"]

    res = s1.vectorize(C, j3)
    assert res == s1[C].iters[5]
    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
def test_compute_at_root_inline():
    dag = auto_scheduler.ComputeDAG(
        conv2d_nchw_bn_relu_auto_scheduler_test(
            N=1, H=224, W=224, CI=3, CO=64, kernel_size=7, strides=2, padding=3
        )
    )
    s0 = dag.get_init_state()

    # data, padding, kernel = 0, 1, 2
    conv = s0.stage_ops[3]
    # bias = 4
    bias_add = s0.stage_ops[5]
    # bn_scale = 6
    bn_mul = s0.stage_ops[7]
    # bn_offset = 8
    bn_add = s0.stage_ops[9]
    relu = s0.stage_ops[10]

    s0.compute_inline(bn_add)
    assert s0[bn_add].compute_at == 1

    s0.compute_inline(bn_mul)
    assert s0[bn_mul].compute_at == 1

    s0.compute_inline(bias_add)
    assert s0[bias_add].compute_at == 1

    assert s0[conv].iters[0].range.extent == 1
    assert s0[conv].iters[1].range.extent == 64
    assert s0[conv].iters[2].range.extent == 112
    assert s0[conv].iters[3].range.extent == 112
    assert s0[conv].iters[4].range.extent == 3
    assert s0[conv].iters[5].range.extent == 7
    assert s0[conv].iters[6].range.extent == 7
    s0.compute_at(conv, relu, s0[relu].iters[2])
    assert s0[conv].compute_at == 2
    s0 = dag.infer_bound_from_state(s0)
    assert s0[conv].iters[0].range.extent == 1
    assert s0[conv].iters[1].range.extent == 1
    assert s0[conv].iters[2].range.extent == 1
    assert s0[conv].iters[3].range.extent == 112
    assert s0[conv].iters[4].range.extent == 3
    assert s0[conv].iters[5].range.extent == 7
    assert s0[conv].iters[6].range.extent == 7

    s0.compute_root(bn_mul)
    assert s0[bn_mul].compute_at == 0

    s0.compute_root(conv)
    assert s0[conv].compute_at == 0
    s0 = dag.infer_bound_from_state(s0)
    assert s0[conv].iters[0].range.extent == 1
    assert s0[conv].iters[1].range.extent == 64
    assert s0[conv].iters[2].range.extent == 112
    assert s0[conv].iters[3].range.extent == 112
    assert s0[conv].iters[4].range.extent == 3
    assert s0[conv].iters[5].range.extent == 7
    assert s0[conv].iters[6].range.extent == 7
Example #13
def test_invalid_compute_dag():
    failed = False
    try:
        A, B = invalid_compute_definition()
        dag = auto_scheduler.ComputeDAG([A, B])
    except tvm.TVMError:
        failed = True

    assert failed
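invalid_compute_definition is not shown in these snippets either. A plausible sketch, assuming (as the name suggests) that it builds a compute definition the ComputeDAG constructor rejects, for instance two reduce axes sharing the same name:

def invalid_compute_definition():
    # Assumed helper: the two reduce axes deliberately share a name, which the
    # auto_scheduler compute analysis is expected to reject with a TVMError.
    A = te.placeholder((10, 10), name="A")
    r1 = te.reduce_axis((0, 2), name="r1")
    r2 = te.reduce_axis((0, 2), name="r1")
    B = te.compute((10,), lambda i: te.sum(A[i][r1 + r2], axis=[r1, r2]), name="B")
    return [A, B]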
Example #14
def test_stage_order():
    N = 512
    A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
    sch = te.create_schedule([D.op, E.op])
    (D_local, ) = sch.cache_write([D], "local")
    (E_local, ) = sch.cache_write([E], "local")
    sch.cache_read(A, "shared", [D_local])
    sch.cache_read(B, "shared", [D_local])
    sch.cache_read(A, "shared", [E_local])
    sch.cache_read(C, "shared", [E_local])

    dag = auto_scheduler.ComputeDAG(sch)
    stage_ops_1 = dag.get_init_state().stage_ops

    # 3 placeholders, 4 x.shared, 2 {D,E}.local, 2 {D,E} compute
    assert len(stage_ops_1) == 11

    # Cache read stage should follow the source stage
    for idx, op in enumerate(stage_ops_1):
        if op.name == "A":
            assert (stage_ops_1[idx + 1].name == "A.d.shared"
                    and stage_ops_1[idx + 2].name == "A.shared")
        elif op.name in ["B", "C"]:
            assert stage_ops_1[idx + 1].name == "%s.shared" % op.name

    # Applying the same schedule to the Ansor state should produce the same stage order
    dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
    state = dag.get_init_state()

    D_local = state.cache_write(D, "local")
    E_local = state.cache_write(E, "local")
    state.cache_read(A, "shared", [D_local])
    state.cache_read(B, "shared", [D_local])
    state.cache_read(A, "shared", [E_local])
    state.cache_read(C, "shared", [E_local])

    stage_ops_2 = state.stage_ops
    assert len(stage_ops_1) == len(stage_ops_2)

    # Cache read stage should follow the source stage
    for op1, op2 in zip(stage_ops_1, stage_ops_2):
        assert op1.name == op2.name
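parallel_matmul_auto_scheduler_test is also not defined here. From the assertions above (three placeholders A, B, C; two matmul stages D and E that both read A), a plausible definition is the following assumed sketch:

@auto_scheduler.register_workload
def parallel_matmul_auto_scheduler_test(N):
    # Assumed workload: two independent matmuls sharing input A, D = A x B and E = A x C.
    A = te.placeholder((N, N), name="A")
    B = te.placeholder((N, N), name="B")
    C = te.placeholder((N, N), name="C")
    k = te.reduce_axis((0, N), name="k")
    D = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="D")
    k = te.reduce_axis((0, N), name="k")
    E = te.compute((N, N), lambda i, j: te.sum(A[i][k] * C[k][j], axis=[k]), name="E")
    return A, B, C, D, E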
Example #15
def get_tiled_matmul():
    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])

    s0 = dag.get_init_state()
    its0 = s0.split(C, s0[C].iters[0], [4, 8, 8])
    its1 = s0.split(C, s0[C].iters[4], [8, 4, 4])
    s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3],
                   s0[C].iters[8]])

    return dag, s0
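A minimal sketch of how the tiled state returned by get_tiled_matmul can be inspected, using only APIs that already appear in these examples:

dag, s0 = get_tiled_matmul()
print(dag.print_python_code_from_state(s0))
sch, args = dag.apply_steps_from_state(s0)
print(tvm.lower(sch, args, simple_mode=True))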
def test_estimate_flop():
    N = 512
    A, B, C = matmul_auto_scheduler_test(N, N, N)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    assert abs(dag.flop_ct - 2 * N**3) < 0.5

    D = topi.nn.relu(C)
    dag = auto_scheduler.ComputeDAG([A, B, D])
    assert abs(dag.flop_ct - (2 * N**3 + N * N)) < 0.5

    # should not count the comparison operations in padding
    E = topi.nn.pad(C, [1, 1])
    dag = auto_scheduler.ComputeDAG([A, B, E])
    assert abs(dag.flop_ct - 2 * N**3) < 0.5

    F = te.compute((N, N),
                   lambda i, j: E[i, j],
                   name='F',
                   attrs={"FLOP": 1234})
    dag = auto_scheduler.ComputeDAG([A, B, F])
    assert abs(dag.flop_ct - (2 * N**3 + 1234)) < 0.5
Example #17
def test_apply_steps_with_layout_rewrite_corner_case():
    A, B, C = matmul_auto_scheduler_test(1, 1, 1)
    dag = auto_scheduler.ComputeDAG([A, B, C])

    s = dag.get_init_state()

    s.compute_root(C)
    i_j_fused = s.fuse(C, [s[C].iters[0], s[C].iters[1]])
    s.parallel(C, i_j_fused)

    _, bufs = dag.apply_steps_from_state(
        s, layout_rewrite=auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
    )
def test_random_model():
    if not tvm.runtime.enabled("llvm"):
        return
    N = 128
    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create('llvm')
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    model = auto_scheduler.RandomModel()
    model.update([], [])
    scores = model.predict(task, [dag.init_state, dag.init_state])
    assert len(scores) == 2
def test_rfactor():
    A, B, C = matmul_auto_scheduler_test(8, 8, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()

    ko, ki = s0.split(C, s0[C].iters[2], [16])

    s1 = s0.copy()
    C_r = s1.rfactor(C, ko, 2)
    """
        Placeholder: A, B
        for i (0,8)
          for j (0,8)
            for k_o (0,32)
              for k_i (0,16)
                C.rf = ...
        for ax0 (0,8)
          for ax1 (0,8)
            for k_o_v (0,32)
              C.repl = ...
    """
    assert s1[C_r].iters[0].range.extent == 8
    assert s1[C_r].iters[1].range.extent == 8
    assert s1[C_r].iters[2].range.extent == 32
    assert s1[C_r].iters[3].range.extent == 16
    assert s1[C].iters[0].range.extent == 8
    assert s1[C].iters[1].range.extent == 8
    assert s1[C].iters[2].range.extent == 32

    s2 = s0.copy()
    C_r = s2.rfactor(C, ki, 2)
    """
        Placeholder: A, B
        for i (0,8)
          for j (0,8)
            for k_i (0,16)
              for k_o (0,32)
                C.rf = ...
        for ax0 (0,8)
          for ax1 (0,8)
            for k_i_v (0,16)
              C.repl = ...
    """
    assert s2[C_r].iters[0].range.extent == 8
    assert s2[C_r].iters[1].range.extent == 8
    assert s2[C_r].iters[2].range.extent == 16
    assert s2[C_r].iters[3].range.extent == 32
    assert s2[C].iters[0].range.extent == 8
    assert s2[C].iters[1].range.extent == 8
    assert s2[C].iters[2].range.extent == 16
def get_sample_records(number):
    """Generate random a list of random MeasureInput and MeasureResult pairs"""
    N = 128
    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create('llvm')
    task = auto_scheduler.SearchTask(dag, workload_key, target)
    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    states = policy.sample_initial_population(number)

    inputs = [auto_scheduler.MeasureInput(task, s) for s in states]
    results = [auto_scheduler.MeasureResult([np.random.uniform(0.5, 1.0)], 0, "", 0.1, 0)
               for _ in range(len(inputs))]

    return task, dag, inputs, results
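A minimal sketch showing how these random records might be written to and read back from a log file, following the save_records / RecordReader pattern used in test_record above:

task, dag, inputs, results = get_sample_records(16)
with tempfile.NamedTemporaryFile() as fp:
    auto_scheduler.save_records(fp.name, inputs, results)
    loaded_inputs, loaded_results = auto_scheduler.RecordReader(fp.name).read_lines()
    assert len(loaded_inputs) == 16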
def test_mutate_tile_size():
    """
    The test case initializes evo search with a batch of "bad" states and checks whether
    the search algorithm can find "good" states by mutating the "bad" states.

    This unit test has been tested with 1,000 runs with no failures, meaning that
    the failure rate is less than 0.1%.
    """
    class MockCostModel(PythonBasedModel):
        """A mock cost model that rates 1 only for the states with tile_k=2."""
        @staticmethod
        def is_good_state(state):
            for line in str(state).split("\n"):
                if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                    return True
            return False

        def predict(self, task, states):
            scores = []
            for state in states:
                scores.append(1 if self.is_good_state(state) else 0)
            return scores

    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test,
                                                    (10, 10, 4))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key,
                                     tvm.target.Target("llvm"))
    policy = auto_scheduler.SketchPolicy(task,
                                         program_cost_model=MockCostModel(),
                                         verbose=0)
    states = policy.sample_initial_population()[:50]

    bad_states = []
    for state in states:
        if not MockCostModel.is_good_state(state):
            bad_states.append(state)

    new_states = policy.evolutionary_search(bad_states, 50)
    found = False
    for state in new_states:
        if MockCostModel.is_good_state(state):
            found = True
            break
    assert found
Example #22
def test_cpu_fusion():
    def fusion_test(N, M):
        A = te.placeholder((N, M), name="A")
        B = te.compute((N, M), lambda i, j: A[i][j], name="B")
        C = te.compute((N, M), lambda i, j: B[i][j], name="C")
        return [A, B, C]

    dag = auto_scheduler.ComputeDAG(fusion_test(64, 32))
    s = dag.get_init_state()
    s.compute_at(1, 2, s.stages[2].iters[1])

    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(compute_dag=dag,
                                     workload_key="test",
                                     target=target)
    names = auto_scheduler.feature.get_per_store_feature_names()
    fea = auto_scheduler.feature.get_per_store_features_from_states([s],
                                                                    task)[0]
    """
    lowered IR:

    Placeholder: A
    for i (0,64)
        for j (0,32)
            for ii (1)
                for jj (1)
                    B[...] = A[...]
            C[...] = B[...]
    """

    # check reuse distance and reuse type after fusion
    found = False
    for stage_fea in fea:
        for i, (name, value) in enumerate(zip(names, stage_fea)):
            if "reuse_type.kSerialMultipleReadWrite" in name and value > 0.5:
                # reuse distance in #iter
                assert fequal(stage_fea[i + 2], 1.0)
                # reuse distance in bytes
                assert fequal(stage_fea[i + 3], math.log2(16 + 1))
                found = True
    assert found
Example #23
def test_record_pragma_storage_align_rfactor():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Rfactor
    ko, _ = s.split(C, s[C].iters[2], [16])
    s.rfactor(C, ko, 2)
    # Pragma
    s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64")
    # StorageAlign
    s.storage_align(C, s[C].iters[-1], 8, 4)

    record_common(dag, s)
Example #24
def test_follow_split_follow_fused_split():
    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()

    C_global = s0.cache_write(C, "global")
    its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True)
    split_step0 = len(s0.transform_steps) - 1
    for level in range(1, 6):
        tmp = s0.copy()
        tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level)
        for i in range(0, level):
            assert tmp[C].iters[i].range.extent == \
                   tmp[C_global].iters[i].range.extent

    its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8])
    split_step1 = len(s0.transform_steps) - 1
    its = []
    for i0, i1 in zip(its0, its1):
        its.append(i0)
        its.append(i1)
    s0.reorder(C, its)
    for i in range(0, 5):
        s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]])

    for level in range(0, 4):
        tmp = s0.copy()
        tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
                               [split_step0, split_step1], level, False)
        assert tmp[C].iters[level + 1].range.extent == \
               tmp[C_global].iters[0].range.extent

    for level in range(0, 4):
        tmp = s0.copy()
        tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
                               [split_step0, split_step1], level, True)
        assert tmp[C].iters[level + 1].range.extent == \
               tmp[C_global].iters[1].range.extent
def test_evo_search():
    """Test evolutionary search. Since we cannot mock random number generator,
    we mocked the cost model to manually guide the evo search. If evo search works
    as expected, it should find the target state after a sufficient number of iterations.
    This unit test has been tested with 1,000 runs with no failures, meaning that
    the failure rate is less than 0.1%.
    """
    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test,
                                                    (10, 10, 4))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key,
                                     tvm.target.Target("llvm"))
    policy = auto_scheduler.SketchPolicy(task,
                                         schedule_cost_model=MockCostModel(),
                                         verbose=0)
    states = policy.sample_initial_population(50)
    pruned_states = []
    for state in states:
        found = False
        for line in str(state).split("\n"):
            # Remove all tile_k=2 states and expect evo search to find them.
            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                found = True
                break
        if not found:
            pruned_states.append(state)

    new_states = policy.evolutionary_search(pruned_states, 50)
    found = False
    for state in new_states:
        for line in str(state).split("\n"):
            # Check if evo search found at least one state with tile_k=2.
            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                found = True
                break
        if found:
            break
    assert found
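MockCostModel is used above but not defined in this snippet; it presumably mirrors the mock defined inside test_mutate_tile_size earlier in this listing, i.e. a PythonBasedModel that scores a state 1 only when it tiles k with a factor of 2. An assumed sketch:

# assumes: from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel
class MockCostModel(PythonBasedModel):
    """Score a state 1 only when it contains a k.1 loop of extent 2."""

    def predict(self, task, states):
        scores = []
        for state in states:
            score = 0
            for line in str(state).split("\n"):
                if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                    score = 1
                    break
            scores.append(score)
        return scores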
def test_dag_measure_local_builder_runner():
    if not tvm.testing.device_enabled("llvm"):
        return

    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512),
                   lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
                   name="C")
    D = topi.nn.relu(C)
    E = topi.nn.relu(D)

    tensors = [A, B, E]
    dag = auto_scheduler.ComputeDAG(tensors)
    key = workload_registry.register_workload_tensors(dag.workload_key(),
                                                      tensors)
    transfer_data = workload_registry.serialize_workload_registry_entry(key)
    f_data = pickle.dumps(transfer_data)
    f_new = pickle.loads(f_data)
    del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
    workload_registry.deserialize_workload_registry_entry(f_new)

    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(compute_dag=dag,
                                     workload_key=key,
                                     target=target)

    for enable_cpu_cache_flush in [True, False]:
        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
        local_builder = auto_scheduler.LocalBuilder()
        local_runner = auto_scheduler.LocalRunner(
            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush)

        bress = local_builder.build([minp])
        assert bress[0].error_no == 0
        mress = local_runner.run([minp], bress)
        assert mress[0].error_no == 0
def test_cache_read_write():
    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)

    data = te.placeholder((N, CI, H, W), name="Data")
    kernel_data = te.placeholder((CO, CI, KH, KW), name="Kernel_data")
    k0, k1 = te.compute(
        kernel_data.shape,
        lambda *i: (kernel_data(*i) + 1, kernel_data(*i) / 2),
        name="Kernel_split",
    )
    kernel = te.compute(kernel_data.shape, lambda *i: k0(*i) + k1(*i), name="Kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
    relu = topi.nn.relu(conv)
    add = topi.add(data, relu)

    dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
    s0 = dag.get_init_state()

    pad_temp = s0.stage_ops[1]
    kernel_split = s0.stage_ops[3]

    # 0: init state
    ori_its = s0[add].iters
    its = s0.split(add, s0[add].iters[0], [2])
    s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]])
    s0.compute_inline(relu)

    # 1: simple cache_write with compute_at
    conv_global = s0.cache_write(conv, "global")
    s0.compute_at(conv_global, conv, s0[conv].iters[3])

    # 2: simple cache_read with compute_at
    kernel_global = s0.cache_read(kernel, "global", [conv_global])
    s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])
    """
        Placeholder: Data, Kernel_data
        for i0 (0,4)
          for i1 (0,512)
            for i2 (0,9)
              for i3 (0,9)
                pad_temp = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel = ...
        for nn (0,4)
          for ff (0,512)
            for yy (0,7)
              for xx (0,7)
                for nn_c (None)
                  for ff_c (None)
                    for yy_c (None)
                      for xx_c (None)
                        for rc (None)
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  Kernel.global = ...
                          for ry (None)
                            for rx (None)
                              compute.global = ...
                compute = ...
        for ax0.0 (0,2)
          for ax1 (0,512)
            for ax0.1 (0,2)
              for ax2 (0,7)
                for ax3 (0,7)
                  T_add = ...
    """
    s1 = dag.infer_bound_from_state(s0)
    assert s1[conv].iters[0].range.extent == 4
    assert s1[conv].iters[1].range.extent == 512
    assert s1[conv].iters[2].range.extent == 7
    assert s1[conv].iters[3].range.extent == 7
    assert s1[kernel_global].iters[0].range.extent == 1
    assert s1[kernel_global].iters[1].range.extent == 1
    assert s1[kernel_global].iters[2].range.extent == 3
    assert s1[kernel_global].iters[3].range.extent == 3
    assert s1[conv_global].iters[0].range.extent == 1
    assert s1[conv_global].iters[1].range.extent == 1
    assert s1[conv_global].iters[2].range.extent == 1
    assert s1[conv_global].iters[3].range.extent == 1
    assert s1[conv_global].iters[4].range.extent == 512
    assert s1[conv_global].iters[5].range.extent == 3
    assert s1[conv_global].iters[6].range.extent == 3

    # 3: two level cache_read with compute_at
    #    preparing for GPU's shared memory & local memory
    pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global])
    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global])
    s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2])
    s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4])

    # 4: cache_read with multi readers
    #    A stage with multiple readers cannot be compute_at-ed to a single consumer
    s0.cache_read(data, "global", [pad_temp, add])
    """
        Placeholder: Data, Kernel_data
        for ax0 (0,4)
          for ax1 (0,512)
            for ax2 (0,7)
              for ax3 (0,7)
                Data.global = ...
        for i0 (0,4)
          for i1 (0,512)
            for i2 (0,9)
              for i3 (0,9)
                pad_temp = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel = ...
        for nn (0,4)
          for ff (0,512)
            for yy (0,7)
              for xx (0,7)
                for nn_c (None)
                  for ff_c (None)
                    for yy_c (None)
                      for ax0 (None)
                        for ax1 (None)
                          for ax2 (None)
                            for ax3 (None)
                              pad_temp.global = ...
                      for xx_c (None)
                        for rc (None)
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  Kernel.global = ...
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  pad_temp.global.shared = ...
                          for ry (None)
                            for rx (None)
                              compute.global = ...
                compute = ...
        for ax0.0 (0,2)
          for ax1 (0,512)
            for ax0.1 (0,2)
              for ax2 (0,7)
                for ax3 (0,7)
                  T_add = ...
    """
    s1 = dag.infer_bound_from_state(s0)
    assert s1[conv].iters[0].range.extent == 4
    assert s1[conv].iters[1].range.extent == 512
    assert s1[conv].iters[2].range.extent == 7
    assert s1[conv].iters[3].range.extent == 7
    assert s1[kernel_global].iters[0].range.extent == 1
    assert s1[kernel_global].iters[1].range.extent == 1
    assert s1[kernel_global].iters[2].range.extent == 3
    assert s1[kernel_global].iters[3].range.extent == 3
    assert s1[conv_global].iters[0].range.extent == 1
    assert s1[conv_global].iters[1].range.extent == 1
    assert s1[conv_global].iters[2].range.extent == 1
    assert s1[conv_global].iters[3].range.extent == 1
    assert s1[conv_global].iters[4].range.extent == 512
    assert s1[conv_global].iters[5].range.extent == 3
    assert s1[conv_global].iters[6].range.extent == 3
    assert s1[pad_temp_global].iters[0].range.extent == 1
    assert s1[pad_temp_global].iters[1].range.extent == 512
    assert s1[pad_temp_global].iters[2].range.extent == 3
    assert s1[pad_temp_global].iters[3].range.extent == 3
    assert s1[pad_temp_shared].iters[0].range.extent == 1
    assert s1[pad_temp_shared].iters[1].range.extent == 1
    assert s1[pad_temp_shared].iters[2].range.extent == 3
    assert s1[pad_temp_shared].iters[3].range.extent == 3

    # 5: cache_write with multi outputs
    # TVM's cache_write actually has a bug with this case:
    #
    # After schedule.cache_write, TVM generates one new stage:
    #   From: kernel_data -> kernel_split -> kernel
    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
    #
    # But with topo-sort analysis, we get:
    #       kernel_data -> kernel_split_global -> kernel_split -> kernel
    #         \                                                /
    #          ----------------> kernel_split ---------------->
    #
    # TODO(jcf94): There seems to be a bug with the input/output tensors. Such a multi-output
    # case should be unusual, so we apply a hack in DoCacheWrite. This should be fixed later.
    kernel_split_global = s0.cache_write(kernel_split, "global")
    """
        Placeholder: Data, Kernel_data
        for ax0 (0,4)
          for ax1 (0,512)
            for ax2 (0,7)
              for ax3 (0,7)
                Data.global = ...
        for i0 (0,4)
          for i1 (0,512)
            for i2 (0,9)
              for i3 (0,9)
                pad_temp = ...
        for i0_c (0,512)
          for i1_c (0,512)
            for i2_c (0,3)
              for i3_c (0,3)
                Kernel_split.global = ...
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        (******* Bug here, there should not be two kernel_split stages *******)
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel_split = ...
        (******* Bug here, there should not be two kernel_split stages *******)
        for i0 (0,512)
          for i1 (0,512)
            for i2 (0,3)
              for i3 (0,3)
                Kernel = ...
        for nn (0,4)
          for ff (0,512)
            for yy (0,7)
              for xx (0,7)
                for nn_c (None)
                  for ff_c (None)
                    for yy_c (None)
                      for ax0 (None)
                        for ax1 (None)
                          for ax2 (None)
                            for ax3 (None)
                              pad_temp.global = ...
                      for xx_c (None)
                        for rc (None)
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  Kernel.global = ...
                          for ax0 (None)
                            for ax1 (None)
                              for ax2 (None)
                                for ax3 (None)
                                  pad_temp.global.shared = ...
                          for ry (None)
                            for rx (None)
                              compute.global = ...
                compute = ...
        for ax0.0 (0,2)
          for ax1 (0,512)
            for ax0.1 (0,2)
              for ax2 (0,7)
                for ax3 (0,7)
                  T_add = ...
    """
    assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters)
    for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
        assert it0.range == it1.range
Example #28
def test_gpu_feature():
    # Use records to build a complicated GPU program
    json_records = "\n".join((
        """{"i": [["[\\"matmul_auto_scheduler_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], ["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""",
    ))

    # load states
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(json_records)
        f.flush()
        inputs, results = auto_scheduler.RecordReader(f.name).read_lines()

        inp = inputs[0]
        dag = auto_scheduler.ComputeDAG(inp.task.workload_key)
        task = auto_scheduler.SearchTask(
            dag,
            inp.task.workload_key,
            inp.task.target,
            None,
            auto_scheduler.HardwareParams(100000, 16, 64, 1 << 30, 1 << 30,
                                          1 << 30, 1 << 30, 1 << 30),
        )

        state = dag.infer_bound_from_state(inputs[0].state)
        fea = auto_scheduler.feature.get_per_store_features_from_states(
            [state], task)[0]
        names = auto_scheduler.feature.get_per_store_feature_names()

        # build feature dict
        fea_dicts = []
        for i in range(len(fea)):
            tmp_dict = {}
            for j in range(len(names)):
                tmp_dict[names[j]] = fea[i][j]
            fea_dicts.append(tmp_dict)
        """
        lowered IR:

        Placeholder: A, B
        blockIdx.x i.0@j.0@ (0,8)
          vthread i.1@j.1@ (0,4)
            threadIdx.x i.2@j.2@ (0,16)
              C.local auto_unroll: 1024
              for k.0 (0,256)
                for ax0@ax1.0 (0,8)
                  threadIdx.x ax0@ax1.1 (0,16)
                    B.shared = ...
                for ax0@ax1.0 (0,64)
                  threadIdx.x ax0@ax1.1 (0,16)
                    A.shared = ...
                for i_c.3 (0,32)
                  for k.2 (0,2)
                    for j_c.4 (0,16)
                      C.local = ...
              for i.3 (0,32)
                for j.3 (0,16)
                  C = ...
        """

        # check gpu-related features
        assert fequal(fea_dicts[0]["blockIdx_x_len"], math.log2(8 + 1))
        assert fequal(fea_dicts[0]["vthread_len"], math.log2(4 + 1))
        assert fequal(fea_dicts[1]["threadIdx_x_len"], math.log2(16 + 1))
        assert fequal(fea_dicts[0]["threadIdx_y_len"], math.log2(1 + 1))
        assert fequal(fea_dicts[2]["blockIdx_z_len"], math.log2(1 + 1))
        assert fequal(fea_dicts[0]["is_gpu"], 1.0)
Example #29
def test_stage_order():
    """Test if the stage order is preserved when recovering a DAG."""
    N = 512
    A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
    sch = te.create_schedule([D.op, E.op])
    (D_local, ) = sch.cache_write([D], "local")
    (E_local, ) = sch.cache_write([E], "local")
    sch.cache_read(A, "shared", [D_local])
    sch.cache_read(B, "shared", [D_local])
    sch.cache_read(A, "shared", [E_local])
    sch.cache_read(C, "shared", [E_local])

    dag = auto_scheduler.ComputeDAG(sch)
    stage_ops_1 = dag.get_init_state().stage_ops

    # 3 placeholders, 4 x.shared, 2 {D,E}.local, 2 {D,E} compute
    assert len(stage_ops_1) == 11

    # Cache read stage should follow the source stage
    for idx, op in enumerate(stage_ops_1):
        if op.name == "A":
            assert (stage_ops_1[idx + 1].name == "A.d.shared"
                    and stage_ops_1[idx + 2].name == "A.shared")
        elif op.name in ["B", "C"]:
            assert stage_ops_1[idx + 1].name == "%s.shared" % op.name

    # Serialize and deserialize the ComputeDAG constructed by a schedule.
    loaded_dag = pickle.loads(pickle.dumps(dag))
    assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
    assert len(loaded_dag.get_init_state().stage_ops) == len(
        dag.get_init_state().stage_ops)

    # Applying the same schedule to the Ansor state should produce the same stage order
    dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
    state = dag.get_init_state()

    D_local = state.cache_write(D, "local")
    E_local = state.cache_write(E, "local")
    state.cache_read(A, "shared", [D_local])
    state.cache_read(B, "shared", [D_local])
    state.cache_read(A, "shared", [E_local])
    state.cache_read(C, "shared", [E_local])

    stage_ops_2 = state.stage_ops
    assert len(stage_ops_1) == len(stage_ops_2)

    # Cache read stage should follow the source stage
    for op1, op2 in zip(stage_ops_1, stage_ops_2):
        assert op1.name == op2.name

    # Serialize and deserialize the ComputeDAG constructed by a list of tensor ops.
    loaded_dag = pickle.loads(pickle.dumps(dag))
    assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
    assert len(loaded_dag.get_init_state().stage_ops) == len(
        dag.get_init_state().stage_ops)

    # Serialize and deserialize the search task.
    task = auto_scheduler.SearchTask(
        dag,
        json.dumps(("test-key", )),
        tvm.target.Target("llvm"),
        hardware_params=auto_scheduler.HardwareParams(100000, 16, 64, 0, 0, 0,
                                                      0, 0),
    )

    task2 = pickle.loads(pickle.dumps(task))
    assert "test-key" in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY
    assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())
    assert len(task.dag.get_init_state().stage_ops) == len(
        task2.dag.get_init_state().stage_ops)
    assert task.workload_key == task2.workload_key
    assert str(task.target) == str(task2.target)
    assert task.hardware_params.num_cores == task2.hardware_params.num_cores
    assert task.hardware_params.vector_unit_bytes == task2.hardware_params.vector_unit_bytes
    assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes
def test_layout_rewrite_correctness():
    N = 128
    target = "llvm"
    workload = matmul_auto_scheduler_test
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner='local',
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state,
                                                     layout_rewrite=False)
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For the CPU tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            new_order = [
                2 + base,
                4 + base,
            ] + list(range(base + 2)) + [
                3 + base,
                5 + base,
            ]
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s,
                         bufs,
                         target=inp.task.target,
                         target_host=inp.task.target_host)
        func_ref = tvm.build(s_ref, bufs_ref, target='llvm')

        ctx = tvm.context(str(inp.task.target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        np.testing.assert_allclose(np_args[0], np_args_ref[0])
        np.testing.assert_allclose(np_args[2], np_args_ref[2])