def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
    """Serialize a traced schedule to text, replay it on a fresh schedule, and
    verify the reproducibility of scheduling.

    For every requested text format, the trace of ``sch`` is serialized and
    replayed onto a fresh schedule built from ``mod``. The replayed module must
    be structurally equal to ``sch.mod``, and the replayed schedule's trace must
    print to the same Python text as the original trace.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format : Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be
        validated. Each entry must be "json" or "python". If a single string,
        validate round-trips through that format only; if a sequence, validate
        round-trips through each listed format, in order.

    Returns
    -------
    new_sch : tir.Schedule
        The freshly replayed schedule for the last format checked, in case it
        is useful to the caller.

    Raises
    ------
    ValueError
        If ``text_format`` is an empty sequence.
    AssertionError
        If ``sch`` has no trace, a format name is unknown, or any round-trip
        check fails.
    """
    trace = sch.trace
    assert trace is not None

    # Normalize to a list so a single format and many formats share one path.
    formats = [text_format] if isinstance(text_format, str) else list(text_format)
    if not formats:
        raise ValueError("text_format must name at least one format")
    for fmt in formats:
        # Raise explicitly rather than `assert`, so the check survives `python -O`
        # instead of silently comparing an un-replayed schedule.
        if fmt not in ("json", "python"):
            raise AssertionError(f"Unknown text format: {fmt}")

    expected_py_repr = "\n".join(trace.as_python())
    new_sch = None
    for fmt in formats:
        # Step 1. Perform a round-trip through the text format onto a fresh schedule
        new_sch = Schedule(mod=mod, debug_mask=debug_mask)
        if fmt == "json":
            Trace.apply_json_to_schedule(json_obj=trace.as_json(), sch=new_sch)
        else:
            # Run the printed Python trace with `tvm.tir` names as globals and the
            # fresh schedule bound to `sch`, which the printed trace refers to.
            py_trace = "\n".join(trace.as_python())
            exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used
        # Step 2. Verify that the round-trip produced the same scheduling
        assert structural_equal(new_sch.mod, sch.mod)
        # Step 3. Check the consistency of the text format between the old and new traces
        assert expected_py_repr == "\n".join(new_sch.trace.as_python())
    # Step 4. Return the new schedule in case it could be useful
    return new_sch
def test_apply_json_to_schedule_1():
    """Replay a JSON-serialized trace onto a fresh `elementwise` schedule and
    check that the resulting `main` function equals `elementwise_inlined`."""
    serialized = _make_trace_2(BlockRV()).as_json()
    schedule = tir.Schedule(elementwise, debug_mask="all")
    Trace.apply_json_to_schedule(json_obj=serialized, sch=schedule)
    tvm.ir.assert_structural_equal(elementwise_inlined, schedule.mod["main"])