def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
    """Serialize a traced schedule to a text format, then replay the serialized trace by
    applying it to a fresh new schedule, verifying the reproducibility of scheduling.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format: Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be validated.
        If a single string, validate round-trips through that one format
        ("json" or "python").  If a sequence of strings, validate each listed
        format in order.

    Returns
    -------
    new_sch : tir.Schedule
        The schedule rebuilt from the serialized trace.  When several formats
        are given, this is the schedule produced for the last one.

    Raises
    ------
    ValueError
        If `text_format` is an empty sequence, so there is nothing to validate.
    AssertionError
        If `sch` has no trace, the format name is unknown, the replayed module
        is not structurally equal to `sch.mod`, or the Python renderings of the
        old and new traces differ.
    """
    if not isinstance(text_format, str):
        formats = list(text_format)
        if not formats:
            # Without this guard there would be no schedule to return below.
            raise ValueError("text_format must contain at least one format, got an empty sequence")
        new_sch = None
        for opt in formats:
            new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt)
        return new_sch

    trace = sch.trace
    assert trace is not None

    # Step 1. Perform a round-trip through the text-format
    new_sch = Schedule(mod=mod, debug_mask=debug_mask)
    if text_format == "json":
        json_obj = trace.as_json()
        Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
    elif text_format == "python":
        # The script executed here is rendered from `sch`'s own trace, and it
        # operates on the name `sch`, which is bound to the fresh schedule.
        py_trace = "\n".join(trace.as_python())
        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used
    else:
        # Raised explicitly so the check is not stripped when asserts are disabled.
        raise AssertionError(f"Unknown text format: {text_format}")

    # Step 2. Verify that the round-trip produced the same scheduling
    assert structural_equal(new_sch.mod, sch.mod)

    # Step 3. Check the consistency of the text format between the old and new traces
    py_repr = "\n".join(trace.as_python())
    new_py_repr = "\n".join(new_sch.trace.as_python())
    assert py_repr == new_py_repr

    # Step 4. Return the new schedule in case it could be useful
    return new_sch
def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool:
    """Compare the traces of two schedules by their string renderings.

    When `remove_decisions` is true, each trace is first rebuilt from its
    instructions with an empty decision map, so only the instructions are
    compared.  Otherwise the schedules' traces are compared as they are.
    """
    lhs, rhs = sch_1.trace, sch_2.trace
    if remove_decisions:
        lhs = Trace(lhs.insts, {})
        rhs = Trace(rhs.insts, {})
    return str(lhs) == str(rhs)
def check_trace(spaces: List[Schedule], expected: List[List[str]]):
    """Assert that the traces of `spaces` match the `expected` set of traces.

    Each entry of `expected` is a list of lines that is joined with newlines.
    Each schedule's trace is rebuilt with an empty decision map, simplified
    with `remove_postproc=True`, and rendered to a stripped string.  Every
    rendered trace must appear in the expected set, and the number of distinct
    rendered traces must equal the number of distinct expected traces.
    """
    expected_traces = set()
    for lines in expected:
        expected_traces.add("\n".join(lines))

    actual_traces = set()
    for space in spaces:
        simplified = Trace(space.trace.insts, {}).simplified(remove_postproc=True)
        rendered = "\n".join(str(simplified).strip().splitlines())
        actual_traces.add(rendered)
        # On failure, show the offending trace on its own lines.
        assert rendered in expected_traces, "\n" + rendered

    assert len(expected_traces) == len(actual_traces)
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the first sketch that reproduces `expected_mod`.

    For each sketch, the instructions whose kind name starts with "Sample" are
    walked in order and paired with the entries of `expected_decision`
    (`(kind_name, decision)` tuples).  A sketch is only replayed when every
    expected decision was paired with an instruction.  The replay applies the
    sketch's instructions with those decisions to a fresh schedule of `mod`
    (with `remove_postproc=True`); on a structural match with `expected_mod`
    the trace round-trip is verified and the sketch index is returned.
    Returns None when no sketch matches.
    """
    num_expected = len(expected_decision)
    for sketch_id, sketch in enumerate(sketches):
        cursor = 0
        new_decisions = {}
        for inst in sketch.trace.insts:
            kind_name = inst.kind.name
            if kind_name.startswith("Sample"):
                assert cursor < num_expected
                # Only advance when the kind matches the next expected entry.
                if kind_name == expected_decision[cursor][0]:
                    new_decisions[inst] = expected_decision[cursor][1]
                    cursor += 1
        if len(new_decisions) == num_expected:
            sch = Schedule(mod, debug_mask=debug_mask)
            replay_trace = Trace(insts=sketch.trace.insts, decisions=new_decisions)
            replay_trace.apply_to_schedule(sch, remove_postproc=True)
            if structural_equal(sch.mod, expected_mod):
                verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
                return sketch_id
    return None
def test_conv2d_winograd_cuda():
    """Replay the CUDA winograd design space with fixed sampling decisions.

    Generates the single design space for `conv2d_winograd_cuda` with the
    default CUDA schedule rules, pins every "Sample*" instruction to a fixed
    decision, replays the trace on a fresh schedule, and checks that the
    resulting module is structurally equal to `_get_mod()`.
    """
    mod = IRModule({"main": conv2d_winograd_cuda})
    context = TuneContext(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultCUDA._sch_rules(),  # pylint: disable=protected-access
    )
    for rule in context.sch_rules:
        rule.initialize_with_tune_context(context)
    generator = PostOrderApply()
    generator.initialize_with_tune_context(context)
    (design_space,) = generator.generate_design_space(mod)

    sampling_insts = [
        inst for inst in design_space.trace.insts if inst.kind.name.startswith("Sample")
    ]
    sampled_values = [
        # data_pack
        [3, 3],
        [64, 2],
        2,
        # inverse
        [3, 3],
        [2, 64],
        2,
        # bgemm
        [1, 1, 1, 1, 6],
        [1, 1, 1, 3, 2],
        [3, 1, 1, 1, 3],
        [4, 2, 1, 4, 4],
        [32, 1, 4],
        1,
        1,
        # root anno
        2,
        # conv2d
        2,
    ]
    decisions = dict(zip(sampling_insts, sampled_values))

    replay_trace = Trace(design_space.trace.insts, decisions=decisions)
    replayed = Schedule(mod=mod)
    replay_trace.apply_to_schedule(replayed, remove_postproc=False)
    tvm.ir.assert_structural_equal(replayed.mod, _get_mod())
def _make_trace_1(b0, l1, l2):  # pylint: disable=invalid-name
    """Build a trace that gets block "block" into `b0` and its loops into `l1`, `l2`."""
    get_block = _make_get_block(name="block", output=b0)
    get_loops = _make_get_loops(input=b0, outputs=[l1, l2])
    return Trace(insts=[get_block, get_loops], decisions={})
def _make_trace_2(b0):  # pylint: disable=invalid-name
    """Build a trace that gets block "B" into `b0` and then compute-inlines it."""
    get_block = _make_get_block(name="B", output=b0)
    inline = _make_compute_inline(input=b0)
    return Trace(insts=[get_block, inline], decisions={})
def _make_trace_4(b0, l1, l2, l3):  # pylint: disable=invalid-name
    """Build a trace that gets block "B", fetches its loop `l1`, and splits it.

    The split instruction is created with inputs `[l1, None, 32]` and outputs
    `[l2, l3]`.
    """
    get_block = _make_get_block(name="B", output=b0)
    get_loops = _make_get_loops(input=b0, outputs=[l1])
    split = _make_split([l1, None, 32], [l2, l3])
    return Trace(insts=[get_block, get_loops, split], decisions={})
def test_conv2d_winograd_cpu():
    """Replay the CPU winograd design space with fixed sampling decisions.

    Generates the single design space for `conv2d_winograd_cpu` with the
    default LLVM schedule rules, drops the last four trace instructions, pins
    every remaining "Sample*" instruction to a fixed decision, replays the
    trace on a fresh schedule, and checks that the resulting module is
    structurally equal to `_get_mod()`.
    """
    mod = IRModule({"main": conv2d_winograd_cpu})
    context = TuneContext(
        mod=mod,
        target=Target("llvm"),
        task_name="Custom Search Space Task",
        sch_rules=DefaultLLVM._sch_rules(),  # pylint: disable=protected-access
    )
    generator = PostOrderApply()
    generator.initialize_with_tune_context(context)
    (design_space,) = generator.generate_design_space(mod)

    # NOTE(review): the trailing four instructions are excluded from the replay;
    # the reason is not visible here -- confirm against the generated trace.
    kept_insts = design_space.trace.insts[:-4]
    sampling_insts = [inst for inst in kept_insts if inst.kind.name.startswith("Sample")]
    sampled_values = [
        # data_pack
        [9, 1],
        [32, 4],
        # input_tile
        4,
        # data_pad
        -2,
        # inverse
        [1, 9],
        [2, 64],
        # bgemm
        [1, 2, 3, 1],
        [1, 1, 1, 6],
        [1, 1, 1, 9],
        [2, 1, 16, 4],
        [16, 8],
    ]
    decisions = dict(zip(sampling_insts, sampled_values))

    replay_trace = Trace(kept_insts, decisions=decisions)
    replayed = Schedule(mod=mod)
    replay_trace.apply_to_schedule(replayed, remove_postproc=False)
    tvm.ir.assert_structural_equal(replayed.mod, _get_mod())
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
) -> Schedule:
    """Round-trip a schedule's trace through JSON and check that it reproduces the schedule.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc used to construct the fresh schedule
    debug_mask : Union[str, int]
        Forwarded to the fresh schedule's constructor.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask

    Returns
    -------
    new_sch : tir.Schedule
        The fresh schedule obtained by replaying the JSON trace.
    """
    original_trace = sch.trace
    assert original_trace is not None

    # Serialize, then replay on a schedule freshly built from `mod`.
    serialized = original_trace.as_json()
    replayed = Schedule(mod=mod, debug_mask=debug_mask)
    Trace.apply_json_to_schedule(json_obj=serialized, sch=replayed)

    # The replay must yield a structurally identical module ...
    assert structural_equal(replayed.mod, sch.mod)

    # ... and a trace whose Python rendering matches the original one.
    old_lines = original_trace.as_python()
    new_lines = replayed.trace.as_python()
    assert "\n".join(old_lines) == "\n".join(new_lines)

    return replayed
def _make_trace_3(b0, b1, add_postproc):  # pylint: disable=invalid-name
    """Build a trace that inlines block "B" and gets block "C".

    When `add_postproc` is truthy, an enter-postproc instruction followed by a
    compute-inline of `b1` is appended after the common instructions.
    """
    insts = [
        _make_get_block(name="B", output=b0),
        _make_compute_inline(input=b0),
        _make_get_block(name="C", output=b1),
    ]
    if add_postproc:
        insts.append(_make_enter_postproc())
        insts.append(_make_compute_inline(input=b1))
    return Trace(insts=insts, decisions={})
def test_apply_json_to_schedule_1():
    """Applying the JSON form of the inline trace to `elementwise` yields `elementwise_inlined`."""
    block_rv = BlockRV()
    serialized = _make_trace_2(block_rv).as_json()
    target_sch = tir.Schedule(elementwise, debug_mask="all")
    Trace.apply_json_to_schedule(serialized, target_sch)
    tvm.ir.assert_structural_equal(elementwise_inlined, target_sch.mod["main"])
def test_trace_construct_pop_2():
    """Popping from an empty trace returns None and leaves its string form empty."""
    empty_trace = Trace([], {})

    rendered_before = str(empty_trace)
    assert rendered_before == ""

    popped = empty_trace.pop()
    assert popped is None

    rendered_after = str(empty_trace)
    assert rendered_after == ""
def apply(self, trace: Trace, _) -> Optional[Trace]:
    """Return a new trace built from `trace`'s instructions with an empty decision map.

    The second positional argument is accepted but not used.
    """
    without_decisions = Trace(trace.insts, {})
    return without_decisions
def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool:
    """Compare two schedules' traces by string rendering, ignoring decisions.

    Each trace is rebuilt from its instructions with an empty decision map
    before being rendered with `str`.
    """
    rendered = [str(Trace(sch.trace.insts, {})) for sch in (sch_1, sch_2)]
    return rendered[0] == rendered[1]