Example 1
def run_tuning():
    print("Begin tuning...")
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # change this to 20000 to achieve the best performance
        runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

    if use_sparse:
        from tvm.topi.sparse.utils import sparse_sketch_rules

        search_policy = [
            auto_scheduler.SketchPolicy(
                task,
                program_cost_model=auto_scheduler.XGBModel(),
                init_search_callbacks=sparse_sketch_rules(),
            ) for task in tasks
        ]

        tuner.tune(tune_option, search_policy=search_policy)
    else:
        tuner.tune(tune_option)
def generate_sketches(workload_func, args, target, print_for_debug=False):
    workload_key = auto_scheduler.make_workload_key(workload_func, args)
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key,
                                     tvm.target.create(target))
    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    return policy.generate_sketches(print_for_debug)
Example 3
def search_common(workload=matmul_auto_scheduler_test,
                  target="llvm",
                  search_policy='empty',
                  seed=random.randint(1, 1 << 30),
                  runner='local',
                  cost_model=auto_scheduler.RandomModel(),
                  num_measure_trials=2,
                  init_search_callbacks=None):
    print("Test %s schedule search with the default search policy" % (target))

    random.seed(seed)
    N = 128
    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create(target)
    task = auto_scheduler.SearchTask(dag, workload_key, target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        init_search_callbacks = init_search_callbacks or []
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == 'empty':
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == 'sketch':
            search_policy = auto_scheduler.SketchPolicy(
                task, init_search_callbacks=init_search_callbacks)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            runner=runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
        sch, args = auto_scheduler.auto_schedule(task, search_policy,
                                                 tuning_options)
        inp, res = auto_scheduler.load_best(log_file, workload_key, target)

        print("==== Python Code ====")
        print(dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(),
                                        np.dot(a.asnumpy(), b.asnumpy()),
                                        rtol=1e-5)
            print("==== Verification passed ====")
        except Exception:
            raise Exception("Error encountered with seed: %d" % (seed))
    print()
Example 4
def test_correctness_layout_rewrite_insert_transform_stage():
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner=measure_ctx.runner,
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(
            inp.state,
            layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage)

        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        tvm.testing.assert_allclose(args[0].asnumpy(),
                                    args_ref[0].asnumpy(),
                                    atol=1e-3,
                                    rtol=1e-3)
        tvm.testing.assert_allclose(args[1].asnumpy(),
                                    args_ref[1].asnumpy(),
                                    atol=1e-3,
                                    rtol=1e-3)
        tvm.testing.assert_allclose(args[2].asnumpy(),
                                    args_ref[2].asnumpy(),
                                    atol=1e-3,
                                    rtol=1e-3)
        del measure_ctx
Example 5
def generate_sketches(
    workload_func, args, target, print_for_debug=False, init_search_callbacks=None
):
    task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target)
    policy = auto_scheduler.SketchPolicy(
        task, verbose=0, init_search_callbacks=init_search_callbacks
    )
    return policy.generate_sketches(print_for_debug)
def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
    )
    task.tune(tune_option, search_policy=search_policy)
Example 7
def get_sample_records(number):
    """Generate a list of random MeasureInput and MeasureResult pairs"""
    N = 128
    task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target="llvm")
    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    states = policy.sample_initial_population()[:number]

    inputs = [auto_scheduler.MeasureInput(task, s) for s in states]
    results = [
        auto_scheduler.MeasureResult([np.random.uniform(0.5, 1.0)], 0, "", 0.1, 0)
        for _ in range(len(inputs))
    ]

    return task, inputs, results
def get_sample_records(number):
    """Generate random a list of random MeasureInput and MeasureResult pairs"""
    N = 128
    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
    dag = auto_scheduler.ComputeDAG(workload_key)
    target = tvm.target.create('llvm')
    task = auto_scheduler.SearchTask(dag, workload_key, target)
    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    states = policy.sample_initial_population(number)

    inputs = [auto_scheduler.MeasureInput(task, s) for s in states]
    results = [auto_scheduler.MeasureResult([np.random.uniform(0.5, 1.0)], 0, "", 0.1, 0)
               for _ in range(len(inputs))]

    return task, dag, inputs, results
Example 9
def generate_sketches(
    workload_func, args, target, print_for_debug=False, init_search_callbacks=None
):
    # NOTE: test_cpu_matmul_sketch and test_cpu_max_pool2d_sketch assume 4 cores to trigger all
    # possible sketch generations.
    task = auto_scheduler.SearchTask(
        func=workload_func,
        args=args,
        target=target,
        hardware_params=auto_scheduler.HardwareParams(num_cores=4, target=target),
    )
    policy = auto_scheduler.SketchPolicy(
        task, verbose=0, init_search_callbacks=init_search_callbacks
    )
    return policy.generate_sketches(print_for_debug)
Example 10
def resume_search(task, logfile_name):
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(logfile_name)
    search_policy = auto_scheduler.SketchPolicy(
        task,
        cost_model,
        init_search_callbacks=[
            auto_scheduler.PreloadMeasuredStates(logfile_name)
        ])
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        measure_callbacks=[auto_scheduler.RecordToFile(logfile_name)])
    sch, args = auto_scheduler.auto_schedule(task,
                                             search_policy,
                                             tuning_options=tune_option)
Example 11
def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    task.tune(tune_option, search_policy=search_policy)

    # Kill the measurement process
    del measure_ctx
Example 12
def test_mutate_parallel():
    """
    The test case initializes evo search with a batch of "bad" states and checks whether
    the search algorithm can find "good" states by mutating the "bad" states.
    """

    class MockCostModel(PythonBasedModel):
        @staticmethod
        def is_good_state(state):
            for line in str(state).split("\n"):
                if (
                    line.find("parallel i.0@ (0") != -1
                    or line.find("parallel [email protected]@ (0") != -1
                    or line.find("parallel [email protected]@i.1@ (0") != -1
                ):
                    return True
            return False

        def predict(self, task, states):
            scores = []
            for state in states:
                scores.append(1 if self.is_good_state(state) else 0)
            return scores

    task = auto_scheduler.SearchTask(
        func=matmul_auto_scheduler_test, args=(1024, 1024, 1024), target="llvm"
    )
    policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)

    found = False
    retry_ct = 0
    while retry_ct < 10 and not found:
        states = policy.sample_initial_population()[:100]
        bad_states = []
        for state in states:
            if not MockCostModel.is_good_state(state):
                bad_states.append(state)

        new_states = policy.evolutionary_search(bad_states, 50)
        for state in new_states:
            if MockCostModel.is_good_state(state):
                found = True
                break
        retry_ct += 1

    assert found
def test_mutate_tile_size():
    """
    The test case initializes evo search with a batch of "bad" states and checks whether
    the search algorithm can find "good" states by mutating the "bad" states.

    This unit test has been tested with 1,000 runs with no failures, meaning that
    the failure rate is less than 0.1%.
    """
    class MockCostModel(PythonBasedModel):
        """A mock cost model that rates 1 only for the states with tile_k=2."""
        @staticmethod
        def is_good_state(state):
            for line in str(state).split("\n"):
                if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                    return True
            return False

        def predict(self, task, states):
            scores = []
            for state in states:
                scores.append(1 if self.is_good_state(state) else 0)
            return scores

    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test,
                                                    (10, 10, 4))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key,
                                     tvm.target.Target("llvm"))
    policy = auto_scheduler.SketchPolicy(task,
                                         program_cost_model=MockCostModel(),
                                         verbose=0)
    states = policy.sample_initial_population()[:50]

    bad_states = []
    for state in states:
        if not MockCostModel.is_good_state(state):
            bad_states.append(state)

    new_states = policy.evolutionary_search(bad_states, 50)
    found = False
    for state in new_states:
        if MockCostModel.is_good_state(state):
            found = True
            break
    assert found
def test_evo_search():
    """Test evolutionary search. Since we cannot mock random number generator,
    we mocked the cost model to manually guide the evo search. If evo search works
    as expected, it should find the target state after a sufficient number of iterations.
    This unit test has been tested with 1,000 runs with no failures, meaning that
    the failure rate is less than 0.1%.
    """
    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test,
                                                    (10, 10, 4))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key,
                                     tvm.target.Target("llvm"))
    policy = auto_scheduler.SketchPolicy(task,
                                         schedule_cost_model=MockCostModel(),
                                         verbose=0)
    states = policy.sample_initial_population(50)
    pruned_states = []
    for state in states:
        found = False
        for line in str(state).split("\n"):
            # Remove all tile_k=2 states and expect evo search to find them.
            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                found = True
                break
        if not found:
            pruned_states.append(state)

    new_states = policy.evolutionary_search(pruned_states, 50)
    found = False
    for state in new_states:
        for line in str(state).split("\n"):
            # Check if evo search found at least one state with tile_k=2.
            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
                found = True
                break
        if found:
            break
    assert found
Example 15
# * We use :code:`RecordToFile` to dump measurement records into a file `sparse_dense.json`.
#   The measurement records can be used to query the history best, resume the search,
#   and do more analyses later.
# * see :any:`auto_scheduler.TuningOptions` for more parameters
# * Here, we need to create an :code:`auto_scheduler.SketchPolicy` object, and add the custom
#   sketch rule as an `init_search_callback`.
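
# NOTE: `meet_condition_func` and `apply_func` used below are defined earlier in the
# original tutorial and are not included in this excerpt. The following is only a
# minimal, hypothetical skeleton showing the expected callback signatures; the real
# rule inspects the stage's op tag and builds a sparse-dense specific schedule.

def meet_condition_func(search_policy, state, stage_id):
    # Decide whether the custom sketch rule should be applied to the given stage.
    # This placeholder applies it to every stage.
    return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST


def apply_func(search_policy, state, stage_id):
    # Return a list of [state, next_stage_id] pairs. This placeholder leaves the
    # state unchanged and simply continues with the previous stage.
    state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    return [[state.state_object, stage_id - 1]]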

log_file = "sparse_dense.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

search_policy = auto_scheduler.SketchPolicy(
    task,
    program_cost_model=auto_scheduler.XGBModel(),
    init_search_callbacks=[
        auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func,
                                               "SparseDense")
    ],
)

######################################################################
# Run the search
# ^^^^^^^^^^^^^^
# Now we get all inputs ready.
# We can kick off the search and let the auto-scheduler do its magic.
# After some measurement trials, we can load the best schedule from the log
# file and apply it.

# Run auto-tuning (search)
# Notice: We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.
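# (The tuning call and the loading of the best record are not shown in this excerpt;
# a minimal sketch using the same `auto_scheduler.auto_schedule` / `load_best` calls
# as the other examples on this page.)
# sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)

# Load the best record measured so far; `inp` is used below to print the schedule.
inp, res = auto_scheduler.load_best(log_file, task.workload_key)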
# We can also print the equivalent python schedule API. This can be used for
# debugging and learning the behavior of the auto-scheduler.
print("Equivalent python schedule:")
print(task.compute_dag.print_python_code_from_state(inp.state))

# Rebuild the binary. This shows how you can apply the best schedule from a
# log file without rerunning the search.
sch, args = task.compute_dag.apply_steps_from_state(inp.state)
func = tvm.build(sch, args, target)

######################################################################
# A more complicated example is to resume the search.
# In this case, we need to create the search policy and cost model by ourselves
# and restore the status of the search policy and cost model from the log file.
# In the example below we resume the status and do 5 more trials.

cost_model = auto_scheduler.XGBModel()
cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
    task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=5,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)

# Kill the measurement process
del measure_ctx
Example 17
def test_layout_rewrite_correctness():
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner="local",
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state,
                                                     layout_rewrite=False)
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            new_order = [2 + base, 4 + base] + list(range(base + 2)) + [3 + base, 5 + base]
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare the device results rather than the untouched host copies: args[0]
        # is an input and must be unchanged, args[2] holds the computed output.
        tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3)
        tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3)
Example 18
def search_common(
        workload=matmul_auto_scheduler_test,
        target="llvm",
        search_policy="sketch",
        seed=0,
        runner="local",
        num_measure_trials=100,
        cost_model=auto_scheduler.RandomModel(),
        init_search_callbacks=None,
):
    print("Test search policy '%s' for '%s'" % (search_policy, target))

    random.seed(seed)
    N = 128
    target = tvm.target.Target(target)
    task = auto_scheduler.create_task(workload, (N, N, N), target)

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        init_search_callbacks = init_search_callbacks or []
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == "empty":
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == "sketch":
            search_policy = auto_scheduler.SketchPolicy(
                task,
                program_cost_model=cost_model,
                init_search_callbacks=init_search_callbacks)
        else:
            raise ValueError("Invalid policy: " + search_policy)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            num_measures_per_round=2,
            early_stopping=1,
            runner=runner,
            verbose=2,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        sch, args = auto_scheduler.auto_schedule(task, search_policy,
                                                 tuning_options)
        inp, res = auto_scheduler.load_best(log_file, task.workload_key,
                                            target)

        print("==== Python Code ====")
        print(task.compute_dag.print_python_code_from_state(inp.state))

        try:
            print("==== Lowered Stmt ====")
            print(tvm.lower(sch, args, simple_mode=True))
            mod = tvm.build(sch, args, target)

            ctx = tvm.context(str(target), 0)
            dtype = task.compute_dag.tensors[0].dtype
            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
            mod(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(),
                                        np.dot(a.asnumpy(), b.asnumpy()),
                                        rtol=1e-5)
            print("==== Verification passed ====")
        except Exception:
            raise Exception("Error encountered with seed: %d" % (seed))
    print()
Example 19
# * In addition, we use :code:`RecordToFile` to dump measurement records into a file `conv2d.json`.
#   The measurement records can be used to query the history best, resume the search,
#   and do more analyses later.
# * see :any:`auto_scheduler.TuningOptions`,
#   :any:`auto_scheduler.LocalRPCMeasureContext` for more parameters.

log_file = "conv2d_x86.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

policy = auto_scheduler.SketchPolicy(task, verbose=0)
print(policy.generate_sketches(True))

######################################################################
# Run the search
# ^^^^^^^^^^^^^^
# Now we get all inputs ready. Pretty simple, isn't it?
# We can kick off the search and let the auto-scheduler do its magic.
# After some measurement trials, we can load the best schedule from the log
# file and apply it.

# Run auto-tuning (search)
# task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)
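
# After applying the best schedule, we can inspect the lowered IR and build the kernel.
# (A minimal sketch; `target` is the compilation target defined when the task was
# created earlier in the original tutorial and is not part of this excerpt.)
print(tvm.lower(sch, args, simple_mode=True))
func = tvm.build(sch, args, target)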
Example 20
def generate_sketches(workload_func, args, target, print_for_debug=False):
    task = auto_scheduler.create_task(workload_func, args, tvm.target.Target(target))
    policy = auto_scheduler.SketchPolicy(task, verbose=0)
    return policy.generate_sketches(print_for_debug)
Example 21
def search_common(
        task=None,
        target="llvm",
        search_policy="sketch",
        runner="local",
        num_measure_trials=100,
        cost_model=auto_scheduler.RandomModel(),
        init_search_callbacks=None,
):
    if task is None:
        task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test,
                                         args=(64, 64, 64),
                                         target=target)
    target = task.target

    print("Test search policy '%s' for '%s'" % (search_policy, target))

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        init_search_callbacks = init_search_callbacks or []
        init_search_callbacks.append(
            auto_scheduler.PreloadMeasuredStates(log_file))

        if search_policy == "empty":
            search_policy = auto_scheduler.EmptyPolicy(task)
        elif search_policy == "sketch":
            search_policy = auto_scheduler.SketchPolicy(
                task,
                program_cost_model=cost_model,
                init_search_callbacks=init_search_callbacks)
        else:
            raise ValueError("Invalid policy: " + search_policy)

        # Tune
        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=num_measure_trials,
            num_measures_per_round=2,
            early_stopping=1,
            runner=runner,
            measure_callbacks=[
                auto_scheduler.RecordToFile(log_file),
                CustomMeasureCallback()
            ],
        )
        task.tune(tuning_options=tuning_options, search_policy=search_policy)

        # Compile with the best schedule
        sch, args = task.apply_best(log_file)
        mod = tvm.build(sch, args, target)

        # Compile with naive schedule for correctness check
        sch, args = task.compute_dag.apply_steps_from_state(
            task.compute_dag.init_state)
        mod_ref = tvm.build(sch, args, "llvm")

        ctx = tvm.device(str(target), 0)
        np_arrays = [
            np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
            for x in args
        ]

        tvm_arrays = [tvm.nd.array(x, ctx) for x in np_arrays]
        mod(*tvm_arrays)
        actual = [x.numpy() for x in tvm_arrays]

        tvm_arrays = [tvm.nd.array(x) for x in np_arrays]
        mod_ref(*tvm_arrays)
        expected = [x.numpy() for x in tvm_arrays]

        for x, y in zip(actual, expected):
            tvm.testing.assert_allclose(x, y, rtol=1e-5)
Example 22
def test_correctness_layout_rewrite_rewrite_for_preTransformed():
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target=target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner=measure_ctx.runner,
            verbose=2,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        task.tune(tuning_options, search_policy=search_policy)
        inp, _ = auto_scheduler.load_best_record(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(
            inp.state, layout_rewrite=auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
        )
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
        np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            new_order = [2 + base, 4 + base] + list(range(base + 2)) + [3 + base, 5 + base]
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        tvm.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3)
        tvm.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3)
        del measure_ctx