Exemplo n.º 1
0
def test_correctness_layout_rewrite_insert_transform_stage():
    """Check that the ``InsertTransformStage`` layout rewrite keeps results intact.

    A 128x128x128 matmul task is tuned for two trials, the best record is
    loaded from the log, and the resulting state is applied twice: once with
    ``layout_rewrite=ComputeDAG.InsertTransformStage`` and once with the
    default arguments.  Both schedules are built and run on the same random
    inputs, and the first three buffers of each run must agree.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            tuning_options = auto_scheduler.TuningOptions(
                num_measure_trials=2,
                runner=measure_ctx.runner,
                verbose=1,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            auto_scheduler.auto_schedule(task, search_policy, tuning_options)
            inp, _ = auto_scheduler.load_best(log_file, task.workload_key,
                                              target)

            # Schedule under test: layout rewrite via an inserted transform stage.
            s, bufs = dag.apply_steps_from_state(
                inp.state,
                layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.
                InsertTransformStage)
            # Reference schedule: same state, default arguments.
            s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)

            # Both builds are fed identical host data, shaped after ``bufs``.
            np_args = [
                np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
                for x in bufs
            ]

            func = tvm.build(s, bufs, target=target)
            func_ref = tvm.build(s_ref, bufs_ref, target=target)

            ctx = tvm.context(str(target))
            ctx_ref = tvm.cpu()

            args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
            args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
            ctx.sync()

            func(*args)
            func_ref(*args_ref)
            ctx.sync()

            # Compare the first three buffers of both runs.
            for idx in range(3):
                tvm.testing.assert_allclose(args[idx].asnumpy(),
                                            args_ref[idx].asnumpy(),
                                            atol=1e-3,
                                            rtol=1e-3)
        finally:
            # Drop the measurement context even when tuning, building or an
            # assertion fails, so it is not left alive after a failed run.
            del measure_ctx
Exemplo n.º 2
0
 def tune(self, n_trial, **kwargs):
     """Run the auto-scheduler on ``self.auto_task`` for ``n_trial`` measurements.

     The instance is published through the module-level ``GLOBAL_TUNER``
     before the search starts.  Extra keyword arguments are accepted and
     ignored.
     """
     global GLOBAL_TUNER
     GLOBAL_TUNER = self

     # Measurements run through this instance's runner; no callbacks are
     # registered.
     options = auto_scheduler.TuningOptions(
         num_measure_trials=n_trial,
         runner=self.measure_ctx.runner,
         measure_callbacks=[],
     )
     auto_scheduler.auto_schedule(self.auto_task, tuning_options=options)
Exemplo n.º 3
0
 def tune(self, n_trial, **kwargs):
     """Run the auto-scheduler on ``self.auto_task``; exit the process on failure.

     The instance is published through the module-level ``GLOBAL_TUNER``
     before the search starts.  ``n_trial`` is the total number of
     measurement trials and ``self.task.n_parallel`` the number measured per
     round.  Extra keyword arguments are accepted and ignored.

     If building the options or running the search raises an ``Exception``,
     the traceback is printed and the process exits with status 1.
     ``KeyboardInterrupt`` and ``SystemExit`` are left to propagate.
     """
     global GLOBAL_TUNER
     GLOBAL_TUNER = self
     try:
         auto_scheduler.auto_schedule(
             self.auto_task,
             tuning_options=auto_scheduler.TuningOptions(
                 num_measure_trials=n_trial,
                 num_measures_per_round=self.task.n_parallel,
                 runner=self.measure_ctx.runner,
                 measure_callbacks=[]))
     except Exception:
         # ``except Exception`` instead of a bare ``except`` so that Ctrl-C
         # and explicit exits are not rewritten into a generic status 1.
         import sys
         import traceback
         traceback.print_exc()
         # ``sys.exit`` rather than the ``exit`` helper, which is only
         # installed by the ``site`` module.
         sys.exit(1)
Exemplo n.º 4
0
def search_common(workload=matmul_auto_scheduler_test,
                  target="llvm",
                  search_policy='empty',
                  seed=random.randint(1, 1 << 30),
                  runner='local',
                  cost_model=auto_scheduler.RandomModel(),
                  num_measure_trials=2,
                  init_search_callbacks=None):
    """Run a short schedule search on a 128x128x128 workload and verify it.

    Parameters
    ----------
    workload
        Workload passed to ``auto_scheduler.make_workload_key`` together with
        the argument tuple ``(128, 128, 128)``.
    target : str
        Target string handed to ``tvm.target.create``.
    search_policy : str
        ``'empty'`` or ``'sketch'``; any other string raises ``ValueError``.
        A non-string value is forwarded to ``auto_schedule`` unchanged.
    seed : int
        Seed for Python's ``random`` module; reported if verification fails.
        NOTE: the default is drawn once, when the function is defined, not
        on every call.
    runner
        Forwarded to ``auto_scheduler.TuningOptions``.
    cost_model
        Cost model given to ``SketchPolicy`` when ``search_policy`` is
        ``'sketch'``.  NOTE: the default instance is created once, when the
        function is defined.
    num_measure_trials : int
        Forwarded to ``auto_scheduler.TuningOptions``.
    init_search_callbacks : list or None
        Extra callbacks for ``SketchPolicy``.  The list is copied, so the
        caller's object is never modified.

    Raises
    ------
    ValueError
        If ``search_policy`` is a string other than ``'empty'``/``'sketch'``.
    Exception
        If lowering, building, running or checking the schedule fails; the
        message carries the seed and the original error is chained as cause.
    """
    print("Test %s schedule search with the default search policy" % (target))

    random.seed(seed)
    N = 128
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Work on a copy: appending to the caller's list would make the
        # preload callback pile up across repeated calls.
        init_search_callbacks = list(init_search_callbacks or [])
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == 'empty':
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == 'sketch':
            # The cost model goes in as the second positional argument so the
            # ``cost_model`` parameter is no longer silently ignored.
            search_policy = auto_scheduler.SketchPolicy(
                task, cost_model, init_search_callbacks=init_search_callbacks)
        elif isinstance(search_policy, str):
            raise ValueError("Invalid policy: " + search_policy)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            runner=runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        sch, args = auto_scheduler.auto_schedule(task, search_policy,
                                                 tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, workload_key, target)

        print("==== Python Code ====")
        print(dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            # The built module must reproduce a plain NumPy matrix product.
            tvm.testing.assert_allclose(c.asnumpy(),
                                        np.dot(a.asnumpy(), b.asnumpy()),
                                        rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as err:
            # Chain the original error so the real failure stays visible next
            # to the seed needed to reproduce it.
            raise Exception("Error encountered with seed: %d" % (seed)) from err
    print()
Exemplo n.º 5
0
def resume_search(task, log_file):
    """Continue searching for ``task`` from the records stored in ``log_file``.

    An ``XGBModel`` is updated from the log file, a ``SketchPolicy`` is
    created with a ``PreloadMeasuredStates`` callback on the same file, and
    five more measurement trials are run with new records written back to
    ``log_file``.  Nothing is returned.
    """
    model = auto_scheduler.XGBModel()
    model.update_from_file(log_file)

    preload_cb = auto_scheduler.PreloadMeasuredStates(log_file)
    policy = auto_scheduler.SketchPolicy(
        task, model, init_search_callbacks=[preload_cb])

    record_cb = auto_scheduler.RecordToFile(log_file)
    options = auto_scheduler.TuningOptions(
        num_measure_trials=5, measure_callbacks=[record_cb])

    # The resulting schedule and tensors are unpacked but not used further.
    _sch, _args = auto_scheduler.auto_schedule(
        task, policy, tuning_options=options)
# Tuning setup.  Measurement records are written to ``log_file`` through the
# ``RecordToFile`` callback, and measurements go through the runner of a
# ``LocalRPCMeasureContext``.
# NOTE(review): ``min_repeat_ms=300`` presumably sets a minimum duration per
# measurement repeat -- confirm against the LocalRPCMeasureContext docs.
log_file = "conv2d.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

######################################################################
# Run the search
# ^^^^^^^^^^^^^^
# Now we get all inputs ready. Pretty simple, isn't it?
# We can kick off the search and let the auto-scheduler do its magic.
# After some measurement trials, it will return the best schedule it found.

# ``task`` is not defined in this excerpt; it must already exist at this point.
sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)

# Kill the process for measurement
# (done only after the search above has returned, since ``tune_option`` uses
# ``measure_ctx.runner``).
del measure_ctx

######################################################################
# We can lower the schedule to see the IR after auto-scheduling.
# The auto-scheduler correctly performs optimizations including multi-level tiling,
# cooperative fetching, unrolling and operator fusion.

print(tvm.lower(sch, args, simple_mode=True))

######################################################################
# Check correctness and evaluate performance
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We build the binary and check its correctness and performance.
Exemplo n.º 7
0
# Name used for the built function and for every file written below.
# ``N``, ``CI``, ``CO``, ``device``, ``target``, ``dense_fwd`` and
# ``num_search_trails`` are expected to be defined before this point.
func_name = "".join([
    "dense_fwd_n", str(N),
    "_ci", str(CI),
    "_co", str(CO),
    "_", str(device),
])
log_file = func_name + ".json"

task = tvm.auto_scheduler.create_task(
    dense_fwd,
    (N, CI, CO, "float32"),
    target,
)

# --- search: tune the task and record every measurement in ``log_file`` ---
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=num_search_trails,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)
sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)
# The measurement context is released once the search has returned.
del measure_ctx

# --- load history: alternative to searching, kept for reference ---
# inp, res = auto_scheduler.load_best(log_file, task.workload_key)
# sch, args = task.compute_dag.apply_steps_from_state(inp.state)

# --- build ---
# NOTE(review): ``ctx`` is not used in the visible part of this script.
ctx = tvm.gpu()
func = tvm.build(sch, args, target, name=func_name)

# --- save: the module itself, then its first imported module ---
obj_fname = func_name + ".o"
ptx_fname = func_name + ".ptx"
func.save(obj_fname)
func.imported_modules[0].save(ptx_fname)
Exemplo n.º 8
0
def test_layout_rewrite_correctness():
    """Check that a layout-rewritten schedule computes the same result.

    A 128x128x128 matmul task is tuned for two trials and the best record is
    applied twice, with ``layout_rewrite=True`` and ``layout_rewrite=False``.
    Random inputs are generated in the shapes of the rewritten buffers; for
    the reference run the second buffer is converted back to a 2-D
    ``(red_dim, out_dim)`` array.  After both functions have run, the device
    buffers at index 0 and 2 must agree.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner="local",
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state,
                                                     layout_rewrite=False)
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]
        # Independent copies, so reshaping the reference weight below does
        # not touch the arrays given to the rewritten function.
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            # Move the two axes folded into ``red_dim`` to the front, keep
            # the remaining axes in order, then collapse to 2-D.
            new_order = ([
                2 + base,
                4 + base,
            ] + list(range(base + 2)) + [
                3 + base,
                5 + base,
            ])
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare the device buffers after execution.  The host arrays in
        # ``np_args``/``np_args_ref`` are never written by the kernels, so
        # comparing those would pass no matter what was computed.
        np.testing.assert_allclose(args[0].asnumpy(),
                                   args_ref[0].asnumpy(),
                                   atol=1e-3,
                                   rtol=1e-3)
        np.testing.assert_allclose(args[2].asnumpy(),
                                   args_ref[2].asnumpy(),
                                   atol=1e-3,
                                   rtol=1e-3)
Exemplo n.º 9
0
def search_common(
        workload=matmul_auto_scheduler_test,
        target="llvm",
        search_policy="sketch",
        seed=0,
        runner="local",
        num_measure_trials=100,
        cost_model=auto_scheduler.RandomModel(),
        init_search_callbacks=None,
):
    """Run a schedule search on a 128x128x128 workload and verify the result.

    Parameters
    ----------
    workload
        Workload passed to ``auto_scheduler.create_task`` together with the
        argument tuple ``(128, 128, 128)``.
    target : str
        Target string handed to ``tvm.target.Target``.
    search_policy : str
        ``"empty"`` or ``"sketch"``; anything else raises ``ValueError``.
    seed : int
        Seed for Python's ``random`` module; reported if verification fails.
    runner
        Forwarded to ``auto_scheduler.TuningOptions``.
    num_measure_trials : int
        Forwarded to ``auto_scheduler.TuningOptions``.
    cost_model
        Passed to ``SketchPolicy`` as ``program_cost_model``.  NOTE: the
        default instance is created once, when the function is defined.
    init_search_callbacks : list or None
        Extra callbacks for ``SketchPolicy``.  The list is copied, so the
        caller's object is never modified.

    Raises
    ------
    ValueError
        If ``search_policy`` is neither ``"empty"`` nor ``"sketch"``.
    Exception
        If lowering, building, running or checking the schedule fails; the
        message carries the seed and the original error is chained as cause.
    """
    print("Test search policy '%s' for '%s'" % (search_policy, target))

    random.seed(seed)
    N = 128
    target = tvm.target.Target(target)
    task = auto_scheduler.create_task(workload, (N, N, N), target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        # Work on a copy: appending to the caller's list would make the
        # preload callback pile up across repeated calls.
        init_search_callbacks = list(init_search_callbacks or [])
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == "empty":
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == "sketch":
            search_policy = auto_scheduler.SketchPolicy(
                task,
                program_cost_model=cost_model,
                init_search_callbacks=init_search_callbacks)
        else:
            raise ValueError("Invalid policy: " + search_policy)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            num_measures_per_round=2,
            early_stopping=1,
            runner=runner,
            verbose=2,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        sch, args = auto_scheduler.auto_schedule(task, search_policy,
                                                 tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key,
                                          target)

        print("==== Python Code ====")
        print(task.compute_dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = task.compute_dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            # The built module must reproduce a plain NumPy matrix product.
            tvm.testing.assert_allclose(c.asnumpy(),
                                        np.dot(a.asnumpy(), b.asnumpy()),
                                        rtol=1e-5)
            print("==== Verification passed ====")
        except Exception as err:
            # Chain the original error so the real failure stays visible next
            # to the seed needed to reproduce it.
            raise Exception("Error encountered with seed: %d" % (seed)) from err
    print()