def test_apply_steps_with_layout_rewrite():
    """Check the second buffer's shape with and without layout rewriting.

    With ``layout_rewrite=False`` its leading dims are (512, 512); with
    ``layout_rewrite=True`` its first five dims are (4, 8, 4, 4, 512).
    """

    def check_dims(buffers, expected):
        # Compare only the leading dims listed in ``expected``, in order.
        for axis, extent in enumerate(expected):
            assert buffers[1].shape[axis] == extent

    dag, state = get_tiled_matmul()

    _, plain_bufs = dag.apply_steps_from_state(state, layout_rewrite=False)
    check_dims(plain_bufs, (512, 512))

    _, rewritten_bufs = dag.apply_steps_from_state(state, layout_rewrite=True)
    check_dims(rewritten_bufs, (4, 8, 4, 4, 512))
def test_measure_local_builder_runner():
    """Build and run the tiled matmul with the local builder and runner.

    Returns early (effectively skipped) when the LLVM runtime is not enabled.
    Otherwise asserts that both the build step and the run step report
    ``error_no == 0``.
    """
    if not tvm.runtime.enabled("llvm"):
        return

    dag, s0 = get_tiled_matmul()
    # ``tvm.target.create`` is deprecated (and removed in newer TVM);
    # construct the Target directly, matching the newer API.
    tgt = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", tgt)
    minp = auto_scheduler.MeasureInput(task, s0)

    local_builder = auto_scheduler.LocalBuilder()
    local_runner = auto_scheduler.LocalRunner(timeout=60)

    bress = local_builder.build([minp])
    assert bress[0].error_no == 0

    mress = local_runner.run([minp], bress)
    assert mress[0].error_no == 0
def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
    """Build and measure the tiled matmul with LocalBuilder / LocalRunner.

    Args:
        enable_cpu_cache_flush: forwarded unchanged to
            ``auto_scheduler.LocalRunner``.

    Returns early when the ``llvm`` device is not enabled; otherwise asserts
    that both building and running report ``error_no == 0``.
    """
    if not tvm.testing.device_enabled("llvm"):
        return

    dag, state = get_tiled_matmul()
    task = auto_scheduler.SearchTask(dag, "test", tvm.target.Target("llvm"))
    measure_input = auto_scheduler.MeasureInput(task, state)

    builder = auto_scheduler.LocalBuilder()
    runner = auto_scheduler.LocalRunner(
        timeout=60,
        enable_cpu_cache_flush=enable_cpu_cache_flush,
    )

    build_results = builder.build([measure_input])
    assert build_results[0].error_no == 0

    run_results = runner.run([measure_input], build_results)
    assert run_results[0].error_no == 0
def test_apply_steps_with_layout_rewrite():
    """Check the second buffer's shape under each ``LayoutRewriteOption``.

    * no ``layout_rewrite`` argument: leading dims are (512, 512);
    * ``REWRITE_FOR_PRE_TRANSFORMED``: first five dims are (4, 8, 4, 4, 512);
    * ``INSERT_TRANSFORM_STAGE``: leading dims are (512, 512).
    """

    def check_dims(buffers, expected):
        # Compare only the leading dims listed in ``expected``, in order.
        for axis, extent in enumerate(expected):
            assert buffers[1].shape[axis] == extent

    dag, state = get_tiled_matmul()

    _, default_bufs = dag.apply_steps_from_state(state)
    check_dims(default_bufs, (512, 512))

    options = auto_scheduler.LayoutRewriteOption

    _, pre_transformed_bufs = dag.apply_steps_from_state(
        state, layout_rewrite=options.REWRITE_FOR_PRE_TRANSFORMED
    )
    check_dims(pre_transformed_bufs, (4, 8, 4, 4, 512))

    _, transform_stage_bufs = dag.apply_steps_from_state(
        state, layout_rewrite=options.INSERT_TRANSFORM_STAGE
    )
    check_dims(transform_stage_bufs, (512, 512))
def test_apply_steps_with_layout_rewrite():
    """Check the second buffer's shape under the ``ComputeDAG`` rewrite modes.

    * no ``layout_rewrite`` argument: leading dims are (512, 512);
    * ``ComputeDAG.RewriteForPreTransformed``: first five dims are
      (4, 8, 4, 4, 512);
    * ``ComputeDAG.InsertTransformStage``: leading dims are (512, 512).
    """

    def check_dims(buffers, expected):
        # Compare only the leading dims listed in ``expected``, in order.
        for axis, extent in enumerate(expected):
            assert buffers[1].shape[axis] == extent

    dag, state = get_tiled_matmul()

    _, default_bufs = dag.apply_steps_from_state(state)
    check_dims(default_bufs, (512, 512))

    dag_cls = auto_scheduler.compute_dag.ComputeDAG

    _, pre_transformed_bufs = dag.apply_steps_from_state(
        state, layout_rewrite=dag_cls.RewriteForPreTransformed
    )
    check_dims(pre_transformed_bufs, (4, 8, 4, 4, 512))

    _, transform_stage_bufs = dag.apply_steps_from_state(
        state, layout_rewrite=dag_cls.InsertTransformStage
    )
    check_dims(transform_stage_bufs, (512, 512))
def test_record():
    """Round-trip one measurement record through a log file.

    Saves a single MeasureInput/MeasureResult pair with ``save_records``,
    reads it back with ``RecordReader.read_lines``, and checks that:

    * exactly one input and one result are recovered;
    * the recovered state equals the original after bound inference;
    * the original state (after bound inference) differs from the DAG's
      initial state.

    Returns early when the LLVM runtime is not enabled.
    """
    # Check the skip condition before doing any DAG/schedule construction.
    if not tvm.runtime.enabled("llvm"):
        return

    dag, s = get_tiled_matmul()
    # ``tvm.target.create`` is deprecated (and removed in newer TVM).
    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", target)

    inp = auto_scheduler.measure.MeasureInput(task, s)
    # NOTE(review): positional args presumably are
    # (costs, error_no, error_msg, all_cost, timestamp) — confirm against
    # MeasureResult's signature.
    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    with tempfile.NamedTemporaryFile() as fp:
        auto_scheduler.save_records(fp.name, [inp], [res])

        log_reader = auto_scheduler.RecordReader(fp.name)
        inputs, results = log_reader.read_lines()
        assert len(inputs) == 1
        assert len(results) == 1

        s1 = dag.infer_bound_from_state(s)
        s2 = dag.infer_bound_from_state(inputs[0].state)

        assert s1 == s2
        assert not (s1 == dag.get_init_state())
def test_infer_bound():
    """Smoke test: bound inference on the tiled matmul state must not raise."""
    dag, tiled_state = get_tiled_matmul()
    dag.infer_bound_from_state(tiled_state)
def test_apply_steps():
    """Smoke test: print the state as Python code, apply its steps, and lower.

    Each call must complete without raising; no outputs are inspected.
    """
    dag, state = get_tiled_matmul()
    dag.print_python_code_from_state(state)

    schedule, args = dag.apply_steps_from_state(state)
    tvm.lower(schedule, args, simple_mode=True)
def test_estimate_flop():
    """The DAG's ``flop_ct`` must equal 2 * 512**3 to within 0.5."""
    dag, _ = get_tiled_matmul()
    # Presumably one multiply plus one add per point of a 512x512x512 matmul.
    expected_flops = 2 * 512**3
    assert abs(dag.flop_ct - expected_flops) < 0.5