def test_correctness_layout_rewrite_insert_transform_stage():
    """Check the InsertTransformStage layout rewrite against the plain schedule.

    Tunes the (N, N, N) ``matmul_auto_scheduler_test`` task on the "llvm"
    target for two measure trials, loads the best record from the log file and
    builds it twice: once with
    ``layout_rewrite=ComputeDAG.InsertTransformStage`` and once with the
    default ``apply_steps_from_state`` arguments.  Both functions are run on
    the same random inputs and all three buffers must agree within 1e-3.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            tuning_options = auto_scheduler.TuningOptions(
                num_measure_trials=2,
                runner=measure_ctx.runner,
                verbose=1,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            auto_scheduler.auto_schedule(task, search_policy, tuning_options)
            inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)

            # Schedule under test: best state with the layout-rewrite option.
            s, bufs = dag.apply_steps_from_state(
                inp.state,
                layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage,
            )
            # Reference: the same state with the default arguments.
            s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)

            np_args = [
                np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs
            ]

            func = tvm.build(s, bufs, target=target)
            func_ref = tvm.build(s_ref, bufs_ref, target=target)

            ctx = tvm.context(str(target))
            ctx_ref = tvm.cpu()

            # Both functions receive arrays created from the same host data.
            args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
            args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
            ctx.sync()

            func(*args)
            func_ref(*args_ref)
            ctx.sync()

            for out, out_ref in zip(args[:3], args_ref[:3]):
                tvm.testing.assert_allclose(
                    out.asnumpy(), out_ref.asnumpy(), atol=1e-3, rtol=1e-3
                )
        finally:
            # Always drop the measure context, including when tuning or an
            # assertion above fails.
            # NOTE(review): presumably releasing this reference is what shuts
            # down the local RPC helpers -- confirm against
            # LocalRPCMeasureContext.
            del measure_ctx
def test_layout_rewrite_correctness():
    """Check that a layout-rewritten schedule matches the unrewritten one.

    Tunes the (N, N, N) ``matmul_auto_scheduler_test`` task on the "llvm"
    target for two measure trials with the "local" runner, loads the best
    record and builds it with ``layout_rewrite=True`` and with
    ``layout_rewrite=False``.  Random inputs are generated for the rewritten
    buffers; the weight (buffer 1) is folded back to a 2-D array for the
    reference function.  Buffers 0 and 2 of both runs must agree within 1e-3.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner="local",
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state, layout_rewrite=False)
        np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # Fold a 6+-dimensional weight back into the 2-D (red_dim, out_dim)
        # array that is passed to the reference function.
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for extent in weight.shape[: base + 2]:
                out_dim *= extent
            new_order = (
                [
                    2 + base,
                    4 + base,
                ]
                + list(range(base + 2))
                + [
                    3 + base,
                    5 + base,
                ]
            )
            np_args_ref[1] = weight.transpose(new_order).reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare the arrays the built functions were actually called with.
        # np_args[i] and np_args_ref[i] (i = 0, 2) are copies of the same host
        # input data, so asserting on them could never detect a wrong result.
        # Buffer 1 is skipped because its layout differs between the two runs.
        # NOTE(review): the tolerance presumably absorbs rounding differences
        # between the two schedules -- confirm 1e-3 is appropriate.
        np.testing.assert_allclose(
            args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3
        )
        np.testing.assert_allclose(
            args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3
        )
def test_correctness_layout_rewrite_rewrite_for_preTransformed():
    """Check the REWRITE_FOR_PRE_TRANSFORMED layout option against the plain schedule.

    Tunes the (N, N, N) ``matmul_auto_scheduler_test`` task on the "llvm"
    target for two measure trials, loads the best record and builds it with
    ``LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED`` and with the default
    ``apply_steps_from_state`` arguments.  Random inputs are generated for the
    rewritten buffers; the weight (buffer 1) is folded back to a 2-D array for
    the reference function.  Buffers 0 and 2 of both runs must agree within
    1e-3.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target=target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            tuning_options = auto_scheduler.TuningOptions(
                num_measure_trials=2,
                runner=measure_ctx.runner,
                verbose=2,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            task.tune(tuning_options, search_policy=search_policy)
            inp, _ = auto_scheduler.load_best_record(log_file, task.workload_key, target)
            s, bufs = dag.apply_steps_from_state(
                inp.state,
                layout_rewrite=auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED,
            )
            s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
            np_args = [
                np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs
            ]
            np_args_ref = [np.array(x) for x in np_args]

            weight = np_args_ref[1]
            # Fold a 6+-dimensional weight back into the 2-D (red_dim, out_dim)
            # array that is passed to the reference function.
            if len(weight.shape) >= 6:
                # For cpu tile structure SSRSRS
                base = len(weight.shape) - 6
                red_dim = weight.shape[2 + base] * weight.shape[4 + base]
                out_dim = weight.shape[3 + base] * weight.shape[5 + base]
                for extent in weight.shape[: base + 2]:
                    out_dim *= extent
                new_order = (
                    [
                        2 + base,
                        4 + base,
                    ]
                    + list(range(base + 2))
                    + [
                        3 + base,
                        5 + base,
                    ]
                )
                np_args_ref[1] = weight.transpose(new_order).reshape((red_dim, out_dim))

            func = tvm.build(s, bufs, target=target)
            func_ref = tvm.build(s_ref, bufs_ref, target=target)

            ctx = tvm.context(str(target))
            ctx_ref = tvm.cpu()

            args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
            args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
            ctx.sync()

            func(*args)
            func_ref(*args_ref)
            ctx.sync()

            # Buffer 1 is skipped: its layout differs between the two runs.
            tvm.testing.assert_allclose(
                args[0].asnumpy(), args_ref[0].asnumpy(), atol=1e-3, rtol=1e-3
            )
            tvm.testing.assert_allclose(
                args[2].asnumpy(), args_ref[2].asnumpy(), atol=1e-3, rtol=1e-3
            )
        finally:
            # Always drop the measure context, including when tuning or an
            # assertion above fails.
            # NOTE(review): presumably releasing this reference is what shuts
            # down the local RPC helpers -- confirm against
            # LocalRPCMeasureContext.
            del measure_ctx