Exemplo n.º 1
0
def search_common(workload=matmul_auto_scheduler_test,
                  target="llvm",
                  search_policy='empty',
                  seed=random.randint(1, 1 << 30),
                  runner='local',
                  cost_model=None,
                  num_measure_trials=2,
                  init_search_callbacks=None):
    """Tune an N x N x N (N = 128) workload with auto_scheduler and verify it.

    The best record found is reloaded from a temporary log file, its Python
    schedule code and lowered IR are printed, and the built module is checked
    against ``np.dot(a, b)`` -- so ``workload`` is expected to compute a
    matrix product ``C = A @ B`` whose first tensor's dtype is used for all
    three buffers.

    Parameters
    ----------
    workload : callable
        Workload function turned into a workload key with arguments
        ``(N, N, N)``.
    target : str
        Target string passed to ``tvm.target.create``.
    search_policy : str or search-policy object
        ``'empty'`` selects ``auto_scheduler.EmptyPolicy``; ``'sketch'``
        selects ``auto_scheduler.SketchPolicy``. Any non-string object is
        passed to ``auto_schedule`` unchanged.
    seed : int
        Seed for Python's ``random`` module; reported if verification fails.
        NOTE: the default value is drawn once, when the function is defined.
    runner : str or runner object
        Forwarded to ``TuningOptions``.
    cost_model : optional
        Program cost model for the sketch policy. ``None`` creates a fresh
        ``auto_scheduler.RandomModel()`` for this call.
    num_measure_trials : int
        Number of measurement trials for the search.
    init_search_callbacks : list, optional
        Extra initial callbacks for the sketch policy. The caller's list is
        copied, never modified.

    Raises
    ------
    ValueError
        If ``search_policy`` is a string other than ``'empty'``/``'sketch'``.
    Exception
        If printing the lowered IR, building, or numerical verification
        fails; the message contains the seed and the original error is
        chained as the cause.
    """
    print("Test %s schedule search with the %s search policy" %
          (target, search_policy))

    # Create the default cost model per call instead of sharing one instance
    # created at function-definition time.
    if cost_model is None:
        cost_model = auto_scheduler.RandomModel()

    random.seed(seed)
    N = 128
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Copy so that appending the preload callback never leaks into the
        # caller's list (which would accumulate across calls).
        init_search_callbacks = list(init_search_callbacks or [])
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == 'empty':
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == 'sketch':
            search_policy = auto_scheduler.SketchPolicy(
                task,
                program_cost_model=cost_model,
                init_search_callbacks=init_search_callbacks)
        elif isinstance(search_policy, str):
            raise ValueError("Invalid policy: " + search_policy)
        # Any other object is treated as a ready-made policy and passed through.

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            runner=runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        sch, args = auto_scheduler.auto_schedule(task, search_policy,
                                                 tuning_options)
        inp, res = auto_scheduler.load_best(log_file, workload_key, target)

        print("==== Python Code ====")
        print(dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(),
                                        np.dot(a.asnumpy(), b.asnumpy()),
                                        rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as err:
            # Keep the seed in the message for reproduction, and chain the
            # real failure so its type and traceback are not lost.
            raise Exception("Error encountered with seed: %d" % (seed)) from err
    print()
Exemplo n.º 2
0
def test_correctness_layout_rewrite_insert_transform_stage():
    """Check that ``InsertTransformStage`` layout rewriting preserves results.

    A 128x128x128 matmul task is tuned for two trials on ``llvm`` through a
    ``LocalRPCMeasureContext`` runner. The best state is then applied twice:
    once with ``ComputeDAG.InsertTransformStage`` layout rewriting and once
    without. Both builds run on identical random inputs, and every buffer
    must match within ``atol=rtol=1e-3``.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        # The measurement context is only needed while tuning. Drop our
        # reference in ``finally`` so it is released even if tuning raises,
        # rather than living on inside a failed test's traceback frame.
        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            tuning_options = auto_scheduler.TuningOptions(
                num_measure_trials=2,
                runner=measure_ctx.runner,
                verbose=1,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        finally:
            del measure_ctx

        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(
            inp.state,
            layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.
            InsertTransformStage)

        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
        # With InsertTransformStage the buffer shapes are expected to match
        # the reference, so both builds are fed the same host arrays.
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare every buffer (inputs and output) between the two builds.
        for actual, expected in zip(args, args_ref):
            tvm.testing.assert_allclose(actual.asnumpy(),
                                        expected.asnumpy(),
                                        atol=1e-3,
                                        rtol=1e-3)
    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)

######################################################################
# Using the record file
# ^^^^^^^^^^^^^^^^^^^^^
# During the search, all measurement records are dumped into the record
# file "conv2d.json". The measurement records can be used to re-apply search results,
# resume the search, and perform other analyses.

######################################################################
# Here is an example where we load the best schedule from a file,
# print the equivalent python schedule API, and build the binary again.

# Load the measurement record for the best schedule.
# load_best returns a pair; this section only uses `inp`, whose `state`
# is printed and replayed below (`res` is not used here).
inp, res = auto_scheduler.load_best(log_file, task.workload_key)

# Print equivalent python schedule API. This can be used for debugging and
# learning the behavior of the auto-scheduler.
print("Equivalent python schedule:")
print(task.compute_dag.print_python_code_from_state(inp.state))

# Rebuild the binary. This shows how you can apply the best schedule from a
# log file without rerunning the search: replaying the recorded state yields
# a schedule and its argument tensors, which are compiled for `target`.
sch, args = task.compute_dag.apply_steps_from_state(inp.state)
func = tvm.build(sch, args, target)

######################################################################
# A more complicated example is to resume the search.
# In this case, we need to create the search policy and cost model by ourselves
# and resume the status of search policy and cost model with the log file.
Exemplo n.º 4
0
    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)

######################################################################
# Using the record file
# ^^^^^^^^^^^^^^^^^^^^^
# During the search, all measurement records are dumped into the record
# file "conv2d.json". The measurement records can be used to re-apply search results,
# resume the search, and perform other analyses.

######################################################################
# Here is an example where we load the best schedule from a file,
# print the equivalent python schedule API, and build the binary again.

# Load the measurement record for the best schedule from "conv2d.json".
# load_best returns a pair; this section only uses `inp`, whose `state`
# is printed and replayed below (`res` is not used here).
inp, res = auto_scheduler.load_best("conv2d.json", task.workload_key)

# Print equivalent python schedule API. This can be used for debugging and
# learning the behavior of the auto-scheduler.
print("Equivalent python schedule:")
print(task.compute_dag.print_python_code_from_state(inp.state))

# Rebuild the binary. This shows how you can apply the best schedule from a
# log file without rerunning the search: replaying the recorded state yields
# a schedule and its argument tensors, which are compiled for `target`.
sch, args = task.compute_dag.apply_steps_from_state(inp.state)
func = tvm.build(sch, args, target)

######################################################################
# A more complicated example is to resume the search.
# In this case, we need to create the search policy and cost model ourselves
# and restore the state of the search policy and cost model from the log file.
Exemplo n.º 5
0
# Check the TVM output `d_tvm` against `d_np` (presumably the NumPy reference
# computed earlier -- defined outside this section) with rtol=1e-3.
tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-3)

######################################################################
# Using the record file
# ^^^^^^^^^^^^^^^^^^^^^
# During the search, all measurement records are dumped into the record
# file "matmul.json". The measurement records can be used to re-apply search results,
# resume the search, and perform other analyses.

######################################################################
# Here is an example where we load the best schedule from a file,
# print the equivalent python schedule API, and build the binary again.

# Load the measurement record for the best schedule from "matmul.json".
# load_best returns a pair; this section only uses `inp`, whose `state`
# is printed and replayed below (`res` is not used here).
inp, res = auto_scheduler.load_best("matmul.json", task.workload_key)

# Print equivalent python schedule API. This can be used for debugging and
# learning the behavior of the auto-scheduler.
print(task.compute_dag.print_python_code_from_state(inp.state))

# Rebuild the binary. This shows how you can apply the best schedule from a
# log file without rerunning the search.
# NOTE(review): unlike the conv2d examples, no target is passed to tvm.build
# here, so its default target is used -- confirm it matches the tuning target.
sch, args = task.compute_dag.apply_steps_from_state(inp.state)
func = tvm.build(sch, args)

######################################################################
# A more complicated example is to resume the search.
# In this case, we need to create the search policy and cost model ourselves
# and restore the state of the search policy and cost model from the log file.
# In the example below we restore the state and run 5 more trials.
Exemplo n.º 6
0
def test_layout_rewrite_correctness():
    """Check that layout-rewritten matmul schedules compute the same result.

    A 128x128x128 matmul task is tuned for two trials on ``llvm`` with the
    local runner. The best state is applied with and without layout
    rewriting; both builds run on the same random data (the weight is
    converted back to the reference layout for the non-rewritten build),
    and the device buffers are compared after execution.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner="local",
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state,
                                                     layout_rewrite=False)
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]
        np_args_ref = [np.array(x) for x in np_args]

        # The rewritten schedule may give the second buffer (the weight) a
        # tiled, higher-rank layout. Convert the reference copy back into a
        # 2-D (red_dim, out_dim) matrix for the non-rewritten build.
        weight = np_args_ref[1]
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            # Move the two reduction axes first, then all spatial axes in
            # their original relative order.
            new_order = ([
                2 + base,
                4 + base,
            ] + list(range(base + 2)) + [
                3 + base,
                5 + base,
            ])
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare the device buffers the kernels actually wrote to. The host
        # arrays np_args / np_args_ref are only the inputs copied into
        # tvm.nd.array, so comparing them would always pass trivially.
        # Index 1 (the weight) is skipped because its layout differs.
        for idx in (0, 2):
            np.testing.assert_allclose(args[idx].asnumpy(),
                                       args_ref[idx].asnumpy(),
                                       atol=1e-3,
                                       rtol=1e-3)
Exemplo n.º 7
0
def search_common(
        workload=matmul_auto_scheduler_test,
        target="llvm",
        search_policy="sketch",
        seed=0,
        runner="local",
        num_measure_trials=100,
        cost_model=None,
        init_search_callbacks=None,
):
    """Tune an N x N x N (N = 128) workload with a search policy and verify it.

    Tuning measures 2 candidates per round with ``early_stopping=1``. The
    best record is reloaded from a temporary log file, and its Python schedule
    code and lowered IR are printed. The built module is then checked against
    ``np.dot(a, b)``, so ``workload`` is expected to compute a matrix product.

    Parameters
    ----------
    workload : callable
        Workload function passed to ``auto_scheduler.create_task`` with
        arguments ``(N, N, N)``.
    target : str
        Target string passed to ``tvm.target.Target``.
    search_policy : str
        ``"empty"`` (``EmptyPolicy``) or ``"sketch"`` (``SketchPolicy``).
    seed : int
        Seed for Python's ``random`` module; reported if verification fails.
    runner : str or runner object
        Forwarded to ``TuningOptions``.
    num_measure_trials : int
        Maximum number of measurement trials.
    cost_model : optional
        Program cost model for the sketch policy. ``None`` creates a fresh
        ``auto_scheduler.RandomModel()`` for this call.
    init_search_callbacks : list, optional
        Extra initial callbacks for the sketch policy. The caller's list is
        copied, never modified.

    Raises
    ------
    ValueError
        If ``search_policy`` is not ``"empty"`` or ``"sketch"``.
    Exception
        If printing the lowered IR, building, or numerical verification
        fails; the message contains the seed and the original error is
        chained as the cause.
    """
    print("Test search policy '%s' for '%s'" % (search_policy, target))

    # Create the default cost model per call instead of sharing one instance
    # created at function-definition time.
    if cost_model is None:
        cost_model = auto_scheduler.RandomModel()

    random.seed(seed)
    N = 128
    target = tvm.target.Target(target)
    task = auto_scheduler.create_task(workload, (N, N, N), target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Copy so that appending the preload callback never leaks into the
        # caller's list (which would accumulate across calls).
        init_search_callbacks = list(init_search_callbacks or [])
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == "empty":
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == "sketch":
            search_policy = auto_scheduler.SketchPolicy(
                task,
                program_cost_model=cost_model,
                init_search_callbacks=init_search_callbacks)
        else:
            # %-formatting (not "+") so non-string input still produces a
            # ValueError instead of a TypeError.
            raise ValueError("Invalid policy: %s" % (search_policy,))

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            num_measures_per_round=2,
            early_stopping=1,
            runner=runner,
            verbose=2,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        sch, args = auto_scheduler.auto_schedule(task, search_policy,
                                                 tuning_options)
        inp, res = auto_scheduler.load_best(log_file, task.workload_key,
                                            target)

        print("==== Python Code ====")
        print(task.compute_dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = task.compute_dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(),
                                        np.dot(a.asnumpy(), b.asnumpy()),
                                        rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as err:
            # Keep the seed in the message for reproduction, and chain the
            # real failure so its type and traceback are not lost.
            raise Exception("Error encountered with seed: %d" % (seed)) from err
    print()