Ejemplo n.º 1
0
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
    """Replay the trace of a schedule on a fresh schedule and verify the result.

    For every requested text format, the trace of `sch` is serialized to that
    format, applied to a new schedule constructed from `mod`, and the outcome
    is compared against `sch` (both the resulting module and the Python
    rendering of the trace).

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Forwarded to the constructor of the fresh schedule.
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format : Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be
        validated.  If a single string, validate the round-trip through that
        format only.  If a sequence of strings, validate the round-trip
        through each format in turn.  Supported formats are "json" and
        "python".

    Returns
    -------
    new_sch : tir.Schedule
        The schedule produced by replaying the trace.  When several formats
        are given, this is the schedule produced by the last one.

    Raises
    ------
    ValueError
        If `text_format` is an empty sequence or names an unsupported format.
    AssertionError
        If `sch` has no trace, or if the replayed schedule or its trace
        differs from the original.
    """
    supported_formats = ("json", "python")

    # Normalize to a tuple so a single string and a sequence share one code path.
    formats = (text_format,) if isinstance(text_format, str) else tuple(text_format)

    # Validate the arguments up front, before any schedule is constructed, so
    # a bad format fails fast and the check is not stripped under `python -O`.
    if not formats:
        raise ValueError("text_format must name at least one format, got an empty sequence")
    for fmt in formats:
        if fmt not in supported_formats:
            raise ValueError(f"Unknown text format: {fmt}")

    trace = sch.trace
    assert trace is not None

    # The Python rendering of the original trace serves both as the script
    # replayed for the "python" format and as the reference in Step 3.
    py_repr = "\n".join(trace.as_python())

    new_sch = None
    for fmt in formats:
        # Step 1. Perform a round-trip through the text format
        new_sch = Schedule(mod=mod, debug_mask=debug_mask)
        if fmt == "json":
            Trace.apply_json_to_schedule(json_obj=trace.as_json(), sch=new_sch)
        else:
            # The executed script is generated from `sch`'s own trace; it runs
            # with the tvm.tir namespace as globals and `sch` bound to the
            # fresh schedule.
            exec(py_repr, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used

        # Step 2. Verify that the round-trip produced the same scheduling
        assert structural_equal(new_sch.mod, sch.mod)

        # Step 3. Check the consistency of the text format between the old and new traces
        new_py_repr = "\n".join(new_sch.trace.as_python())
        assert py_repr == new_py_repr

    # Step 4. Return the new schedule in case it could be useful
    return new_sch
Ejemplo n.º 2
0
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
) -> Schedule:
    """Check that the JSON form of a schedule's trace reproduces the schedule.

    The trace of `sch` is exported as JSON and applied to a schedule freshly
    constructed from `mod`.  The replayed schedule must be structurally equal
    to `sch`, and both traces must render to the same Python text.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Forwarded to the constructor of the fresh schedule.
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask

    Returns
    -------
    new_sch : tir.Schedule
        The schedule obtained by replaying the JSON trace.
    """

    def _render(recorded) -> str:
        # One instruction per line, as produced by the trace's Python export.
        return "\n".join(recorded.as_python())

    # Export the recorded trace as JSON.
    original_trace = sch.trace
    assert original_trace is not None
    serialized = original_trace.as_json()

    # Replay the JSON on a schedule built from scratch and compare the modules.
    replayed = Schedule(mod=mod, debug_mask=debug_mask)
    Trace.apply_json_to_schedule(json_obj=serialized, sch=replayed)
    assert structural_equal(replayed.mod, sch.mod)

    # Both traces must also print to identical Python text.
    assert _render(original_trace) == _render(replayed.trace)

    # Hand the replayed schedule back in case the caller wants to inspect it.
    return replayed
def test_apply_json_to_schedule_1():
    """Apply a JSON-serialized trace to a schedule and check the resulting function.

    The trace comes from `_make_trace_2`, is exported to JSON, and is applied
    to a schedule over `elementwise`; the schedule's "main" function must then
    be structurally equal to `elementwise_inlined`.
    """
    # Build the trace and export it before the target schedule exists.
    serialized = _make_trace_2(BlockRV()).as_json()

    # Replay the JSON on a schedule with all debug checks enabled.
    target = tir.Schedule(elementwise, debug_mask="all")
    Trace.apply_json_to_schedule(serialized, target)

    tvm.ir.assert_structural_equal(elementwise_inlined, target.mod["main"])